diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION
index 855eb5bf77f16..f52d785e05cdd 100644
--- a/R/pkg/DESCRIPTION
+++ b/R/pkg/DESCRIPTION
@@ -13,6 +13,7 @@ Authors@R: c(person("Shivaram", "Venkataraman", role = c("aut", "cre"),
License: Apache License (== 2.0)
URL: http://www.apache.org/ http://spark.apache.org/
BugReports: http://spark.apache.org/contributing.html
+SystemRequirements: Java (== 8)
Depends:
R (>= 3.0),
methods
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 5f8209689a559..c575fe255f57a 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -352,6 +352,7 @@ exportMethods("%<=>%",
"sinh",
"size",
"skewness",
+ "slice",
"sort_array",
"soundex",
"spark_partition_id",
diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R
index 7244cc9f9e38e..4c87f64e7f0e1 100644
--- a/R/pkg/R/client.R
+++ b/R/pkg/R/client.R
@@ -60,6 +60,41 @@ generateSparkSubmitArgs <- function(args, sparkHome, jars, sparkSubmitOpts, pack
combinedArgs
}
+checkJavaVersion <- function() {
+ javaBin <- "java"
+ javaHome <- Sys.getenv("JAVA_HOME")
+ javaReqs <- utils::packageDescription(utils::packageName(), fields = c("SystemRequirements"))
+ sparkJavaVersion <- as.numeric(tail(strsplit(javaReqs, "[(=)]")[[1]], n = 1L))
+ if (javaHome != "") {
+ javaBin <- file.path(javaHome, "bin", javaBin)
+ }
+
+ # If java is missing from PATH, we get an error in Unix and a warning in Windows
+ javaVersionOut <- tryCatch(
+ launchScript(javaBin, "-version", wait = TRUE, stdout = TRUE, stderr = TRUE),
+ error = function(e) {
+ stop("Java version check failed. Please make sure Java is installed",
+ " and set JAVA_HOME to point to the installation directory.", e)
+ },
+ warning = function(w) {
+ stop("Java version check failed. Please make sure Java is installed",
+ " and set JAVA_HOME to point to the installation directory.", w)
+ })
+ javaVersionFilter <- Filter(
+ function(x) {
+ grepl(" version", x)
+ }, javaVersionOut)
+
+ javaVersionStr <- strsplit(javaVersionFilter[[1]], "[\"]")[[1L]][2]
+ # javaVersionStr is of the form 1.8.0_92.
+ # Extract 8 from it to compare to sparkJavaVersion
+ javaVersionNum <- as.integer(strsplit(javaVersionStr, "[.]")[[1L]][2])
+ if (javaVersionNum != sparkJavaVersion) {
+ stop(paste("Java version", sparkJavaVersion, "is required for this package; found version:",
+ javaVersionStr))
+ }
+}
+
launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts, packages) {
sparkSubmitBinName <- determineSparkSubmitBin()
if (sparkHome != "") {
@@ -67,6 +102,7 @@ launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts, packages) {
} else {
sparkSubmitBin <- sparkSubmitBinName
}
+
combinedArgs <- generateSparkSubmitArgs(args, sparkHome, jars, sparkSubmitOpts, packages)
cat("Launching java with spark-submit command", sparkSubmitBin, combinedArgs, "\n")
invisible(launchScript(sparkSubmitBin, combinedArgs))
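For reference, the new `checkJavaVersion()` helper above pulls the quoted version string out of the first line of `java -version` output (e.g. `java version "1.8.0_92"`) and compares its second dot-separated field against the `SystemRequirements` entry in DESCRIPTION. A minimal Scala sketch of the same parsing logic, for illustration only (the sample output line is an assumed example, not part of the patch):

```scala
// Illustrative sketch only -- mirrors the major-version extraction done by checkJavaVersion().
object JavaVersionCheck {
  // Typical first line of `java -version` output: java version "1.8.0_92"
  def majorVersion(versionLine: String): Int = {
    val quoted = versionLine.split("\"")(1)   // "1.8.0_92"
    quoted.split("\\.")(1).toInt              // second dot-separated field -> 8
  }

  def main(args: Array[String]): Unit = {
    val required = 8
    val found = majorVersion("java version \"1.8.0_92\"")
    require(found == required, s"Java $required is required; found major version $found")
    println(s"Detected Java major version: $found")
  }
}
```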
diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index 1f97054443e1b..fcb3521f901ea 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -208,16 +208,20 @@ NULL
#' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1)))
#' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1)))
#' head(select(tmp, array_position(tmp$v1, 21), array_sort(tmp$v1)))
-#' head(select(tmp, flatten(tmp$v1)))
+#' head(select(tmp, flatten(tmp$v1), reverse(tmp$v1)))
#' tmp2 <- mutate(tmp, v2 = explode(tmp$v1))
#' head(tmp2)
#' head(select(tmp, posexplode(tmp$v1)))
+#' head(select(tmp, slice(tmp$v1, 2L, 2L)))
#' head(select(tmp, sort_array(tmp$v1)))
#' head(select(tmp, sort_array(tmp$v1, asc = FALSE)))
#' tmp3 <- mutate(df, v3 = create_map(df$model, df$cyl))
#' head(select(tmp3, map_keys(tmp3$v3)))
#' head(select(tmp3, map_values(tmp3$v3)))
-#' head(select(tmp3, element_at(tmp3$v3, "Valiant")))}
+#' head(select(tmp3, element_at(tmp3$v3, "Valiant")))
+#' tmp4 <- mutate(df, v4 = create_array(df$mpg, df$cyl), v5 = create_array(df$hp))
+#' head(select(tmp4, concat(tmp4$v4, tmp4$v5)))
+#' head(select(tmp, concat(df$mpg, df$cyl, df$hp)))}
NULL
#' Window functions for Column operations
@@ -1259,9 +1263,9 @@ setMethod("quarter",
})
#' @details
-#' \code{reverse}: Reverses the string column and returns it as a new string column.
+#' \code{reverse}: Returns a reversed string or an array with reverse order of elements.
#'
-#' @rdname column_string_functions
+#' @rdname column_collection_functions
#' @aliases reverse reverse,Column-method
#' @note reverse since 1.5.0
setMethod("reverse",
@@ -1912,6 +1916,7 @@ setMethod("atan2", signature(y = "Column"),
#' @details
#' \code{datediff}: Returns the number of days from \code{y} to \code{x}.
+#' If \code{y} is later than \code{x} then the result is positive.
#'
#' @rdname column_datetime_diff_functions
#' @aliases datediff datediff,Column-method
@@ -1971,7 +1976,10 @@ setMethod("levenshtein", signature(y = "Column"),
})
#' @details
-#' \code{months_between}: Returns number of months between dates \code{y} and \code{x}.
+#' \code{months_between}: Returns number of months between dates \code{y} and \code{x}.
+#' If \code{y} is later than \code{x}, then the result is positive. If \code{y} and \code{x}
+#' are on the same day of month, or both are the last day of month, time of day will be ignored.
+#' Otherwise, the difference is calculated based on 31 days per month, and rounded to 8 digits.
#'
#' @rdname column_datetime_diff_functions
#' @aliases months_between months_between,Column-method
@@ -2050,20 +2058,10 @@ setMethod("countDistinct",
#' @details
#' \code{concat}: Concatenates multiple input columns together into a single column.
-#' If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string.
+#' The function works with strings, binary and compatible array columns.
#'
-#' @rdname column_string_functions
+#' @rdname column_collection_functions
#' @aliases concat concat,Column-method
-#' @examples
-#'
-#' \dontrun{
-#' # concatenate strings
-#' tmp <- mutate(df, s1 = concat(df$Class, df$Sex),
-#' s2 = concat(df$Class, df$Sex, df$Age),
-#' s3 = concat(df$Class, df$Sex, df$Age, df$Class),
-#' s4 = concat_ws("_", df$Class, df$Sex),
-#' s5 = concat_ws("+", df$Class, df$Sex, df$Age, df$Survived))
-#' head(tmp)}
#' @note concat since 1.5.0
setMethod("concat",
signature(x = "Column"),
@@ -2404,6 +2402,13 @@ setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"),
#' @param sep separator to use.
#' @rdname column_string_functions
#' @aliases concat_ws concat_ws,character,Column-method
+#' @examples
+#'
+#' \dontrun{
+#' # concatenate strings
+#' tmp <- mutate(df, s1 = concat_ws("_", df$Class, df$Sex),
+#' s2 = concat_ws("+", df$Class, df$Sex, df$Age, df$Survived))
+#' head(tmp)}
#' @note concat_ws since 1.5.0
setMethod("concat_ws", signature(sep = "character", x = "Column"),
function(sep, x, ...) {
@@ -3058,7 +3063,8 @@ setMethod("array_sort",
})
#' @details
-#' \code{flatten}: Transforms an array of arrays into a single array.
+#' \code{flatten}: Creates a single array from an array of arrays.
+#' If a structure of nested arrays is deeper than two levels, only one level of nesting is removed.
#'
#' @rdname column_collection_functions
#' @aliases flatten flatten,Column-method
@@ -3138,6 +3144,22 @@ setMethod("size",
column(jc)
})
+#' @details
+#' \code{slice}: Returns an array containing all the elements in x from the index start
+#' (or starting from the end if start is negative) with the specified length.
+#'
+#' @rdname column_collection_functions
+#' @param start an index indicating the first element occurring in the result.
+#' @param length the number of consecutive elements chosen for the result.
+#' @aliases slice slice,Column-method
+#' @note slice since 2.4.0
+setMethod("slice",
+ signature(x = "Column"),
+ function(x, start, length) {
+ jc <- callJStatic("org.apache.spark.sql.functions", "slice", x@jc, start, length)
+ column(jc)
+ })
+
#' @details
#' \code{sort_array}: Sorts the input array in ascending or descending order according to
#' the natural ordering of the array elements. NA elements will be placed at the beginning of
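The new R wrappers (`slice`, plus the array-aware `concat` and `reverse`) delegate via `callJStatic` to the corresponding methods in `org.apache.spark.sql.functions`. A short Scala sketch of those JVM-side functions, assuming a Spark 2.4+ session and made-up toy data:

```scala
// Sketch of the JVM-side collection functions the new R wrappers call into.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{array, col, concat, flatten, reverse, slice}

object CollectionFunctionsExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("collection-fns").getOrCreate()
    import spark.implicits._

    val df = Seq((Seq(1, 2, 3), Seq(4, 5, 6))).toDF("v1", "v2")

    df.select(
      slice(col("v1"), 2, 2),               // elements 2..3 of v1 -> [2, 3]
      reverse(col("v1")),                   // [3, 2, 1]
      concat(col("v1"), col("v2")),         // [1, 2, 3, 4, 5, 6]
      flatten(array(col("v1"), col("v2")))  // [1, 2, 3, 4, 5, 6]
    ).show(false)

    spark.stop()
  }
}
```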
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 5faa51eef3abd..3ea181157b644 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -624,7 +624,7 @@ setGeneric("summarize", function(x, ...) { standardGeneric("summarize") })
#' @rdname summary
setGeneric("summary", function(object, ...) { standardGeneric("summary") })
-setGeneric("toJSON", function(x) { standardGeneric("toJSON") })
+setGeneric("toJSON", function(x, ...) { standardGeneric("toJSON") })
setGeneric("toRDD", function(x) { standardGeneric("toRDD") })
@@ -817,7 +817,7 @@ setGeneric("collect_set", function(x) { standardGeneric("collect_set") })
#' @rdname column
setGeneric("column", function(x) { standardGeneric("column") })
-#' @rdname column_string_functions
+#' @rdname column_collection_functions
#' @name NULL
setGeneric("concat", function(x, ...) { standardGeneric("concat") })
@@ -1134,7 +1134,7 @@ setGeneric("regexp_replace",
#' @name NULL
setGeneric("repeat_string", function(x, n) { standardGeneric("repeat_string") })
-#' @rdname column_string_functions
+#' @rdname column_collection_functions
#' @name NULL
setGeneric("reverse", function(x) { standardGeneric("reverse") })
@@ -1194,6 +1194,10 @@ setGeneric("size", function(x) { standardGeneric("size") })
#' @name NULL
setGeneric("skewness", function(x) { standardGeneric("skewness") })
+#' @rdname column_collection_functions
+#' @name NULL
+setGeneric("slice", function(x, start, length) { standardGeneric("slice") })
+
#' @rdname column_collection_functions
#' @name NULL
setGeneric("sort_array", function(x, asc = TRUE) { standardGeneric("sort_array") })
diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R
index 38ee79477996f..f7c1663d32c96 100644
--- a/R/pkg/R/sparkR.R
+++ b/R/pkg/R/sparkR.R
@@ -167,6 +167,7 @@ sparkR.sparkContext <- function(
submitOps <- getClientModeSparkSubmitOpts(
Sys.getenv("SPARKR_SUBMIT_ARGS", "sparkr-shell"),
sparkEnvirMap)
+ checkJavaVersion()
launchBackend(
args = path,
sparkHome = sparkHome,
@@ -193,7 +194,7 @@ sparkR.sparkContext <- function(
# Don't use readString() so that we can provide a useful
# error message if the R and Java versions are mismatched.
- authSecretLen = readInt(f)
+ authSecretLen <- readInt(f)
if (length(authSecretLen) == 0 || authSecretLen == 0) {
stop("Unexpected EOF in JVM connection data. Mismatched versions?")
}
diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R
index f1b5ecaa017df..c3501977e64bc 100644
--- a/R/pkg/R/utils.R
+++ b/R/pkg/R/utils.R
@@ -746,7 +746,7 @@ varargsToJProperties <- function(...) {
props
}
-launchScript <- function(script, combinedArgs, wait = FALSE) {
+launchScript <- function(script, combinedArgs, wait = FALSE, stdout = "", stderr = "") {
if (.Platform$OS.type == "windows") {
scriptWithArgs <- paste(script, combinedArgs, sep = " ")
# on Windows, intern = F seems to mean output to the console. (documentation on this is missing)
@@ -756,7 +756,7 @@ launchScript <- function(script, combinedArgs, wait = FALSE) {
# stdout = F means discard output
# stdout = "" means to its console (default)
# Note that the console of this child process might not be the same as the running R process.
- system2(script, combinedArgs, stdout = "", wait = wait)
+ system2(script, combinedArgs, stdout = stdout, wait = wait, stderr = stderr)
}
}
diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R
index b8bfded0ebf2d..13b55ac6e6e3c 100644
--- a/R/pkg/tests/fulltests/test_sparkSQL.R
+++ b/R/pkg/tests/fulltests/test_sparkSQL.R
@@ -1479,7 +1479,7 @@ test_that("column functions", {
df5 <- createDataFrame(list(list(a = "010101")))
expect_equal(collect(select(df5, conv(df5$a, 2, 16)))[1, 1], "15")
- # Test array_contains(), array_max(), array_min(), array_position() and element_at()
+ # Test array_contains(), array_max(), array_min(), array_position(), element_at() and reverse()
df <- createDataFrame(list(list(list(1L, 2L, 3L)), list(list(6L, 5L, 4L))))
result <- collect(select(df, array_contains(df[[1]], 1L)))[[1]]
expect_equal(result, c(TRUE, FALSE))
@@ -1496,6 +1496,13 @@ test_that("column functions", {
result <- collect(select(df, element_at(df[[1]], 1L)))[[1]]
expect_equal(result, c(1, 6))
+ result <- collect(select(df, reverse(df[[1]])))[[1]]
+ expect_equal(result, list(list(3L, 2L, 1L), list(4L, 5L, 6L)))
+
+ df2 <- createDataFrame(list(list("abc")))
+ result <- collect(select(df2, reverse(df2[[1]])))[[1]]
+ expect_equal(result, "cba")
+
# Test array_sort() and sort_array()
df <- createDataFrame(list(list(list(2L, 1L, 3L, NA)), list(list(NA, 6L, 5L, NA, 4L))))
@@ -1507,7 +1514,18 @@ test_that("column functions", {
result <- collect(select(df, sort_array(df[[1]])))[[1]]
expect_equal(result, list(list(NA, 1L, 2L, 3L), list(NA, NA, 4L, 5L, 6L)))
- # Test flattern
+ # Test slice()
+ df <- createDataFrame(list(list(list(1L, 2L, 3L)), list(list(4L, 5L))))
+ result <- collect(select(df, slice(df[[1]], 2L, 2L)))[[1]]
+ expect_equal(result, list(list(2L, 3L), list(5L)))
+
+ # Test concat()
+ df <- createDataFrame(list(list(list(1L, 2L, 3L), list(4L, 5L, 6L)),
+ list(list(7L, 8L, 9L), list(10L, 11L, 12L))))
+ result <- collect(select(df, concat(df[[1]], df[[2]])))[[1]]
+ expect_equal(result, list(list(1L, 2L, 3L, 4L, 5L, 6L), list(7L, 8L, 9L, 10L, 11L, 12L)))
+
+ # Test flatten()
df <- createDataFrame(list(list(list(list(1L, 2L), list(3L, 4L))),
list(list(list(5L, 6L), list(7L, 8L)))))
result <- collect(select(df, flatten(df[[1]])))[[1]]
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index cccd3ea457ba4..0791fe856ef15 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -178,4 +178,6 @@ private[spark] class TaskContextImpl(
private[spark] def fetchFailed: Option[FetchFailedException] = _fetchFailedException
+ // TODO: should we make this public and define it in `TaskContext`?
+ private[spark] def getLocalProperties(): Properties = localProperties
}
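`getLocalProperties()` above exposes the task's full `Properties` object internally; the existing public way to read a single entry from task code is `TaskContext.getLocalProperty`. A small sketch of that public API (the property key is an illustrative name):

```scala
// Sketch: reading a driver-set local property inside a task via the public TaskContext API.
import org.apache.spark.{SparkConf, SparkContext, TaskContext}

object LocalPropertyExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("local-props"))
    // Local properties set on the driver are propagated to tasks of jobs submitted afterwards.
    sc.setLocalProperty("my.custom.tag", "nightly-batch")
    val tags = sc.parallelize(1 to 2, 2)
      .map(_ => TaskContext.get().getLocalProperty("my.custom.tag"))
      .collect()
    println(tags.mkString(", "))  // nightly-batch, nightly-batch
    sc.stop()
  }
}
```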
diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
index a76283e33fa65..33901bc8380e9 100644
--- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala
+++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -212,9 +212,15 @@ case object TaskResultLost extends TaskFailedReason {
* Task was killed intentionally and needs to be rescheduled.
*/
@DeveloperApi
-case class TaskKilled(reason: String) extends TaskFailedReason {
+case class TaskKilled(
+ reason: String,
+ accumUpdates: Seq[AccumulableInfo] = Seq.empty,
+ private[spark] val accums: Seq[AccumulatorV2[_, _]] = Nil)
+ extends TaskFailedReason {
+
override def toErrorString: String = s"TaskKilled ($reason)"
override def countTowardsTaskFailures: Boolean = false
+
}
/**
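With `TaskKilled` now carrying `accumUpdates`, listeners can report metrics for killed tasks as well. A hedged sketch of how a `SparkListener` might read the new field (the class name is illustrative, and this assumes the patched `TaskKilled` shape above):

```scala
// Sketch: observing the richer TaskKilled reason from a SparkListener.
import org.apache.spark.TaskKilled
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}

class TaskKillLogger extends SparkListener {
  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
    taskEnd.reason match {
      case tk: TaskKilled =>
        // accumUpdates lets us see accumulator values (e.g. records read) for killed tasks too.
        val updates = tk.accumUpdates.flatMap(a => a.name.map(n => s"$n=${a.update.getOrElse("")}"))
        println(s"Task ${taskEnd.taskInfo.taskId} killed (${tk.reason}); " +
          s"accumulators: ${updates.mkString(", ")}")
      case _ => // other end reasons are not of interest here
    }
  }
}
// Register with: sparkContext.addSparkListener(new TaskKillLogger)
```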
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index 087e9c31a9c9a..4baf032f0e9c6 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -310,6 +310,7 @@ private[spark] class SparkSubmit extends Logging {
val isMesosCluster = clusterManager == MESOS && deployMode == CLUSTER
val isStandAloneCluster = clusterManager == STANDALONE && deployMode == CLUSTER
val isKubernetesCluster = clusterManager == KUBERNETES && deployMode == CLUSTER
+ val isMesosClient = clusterManager == MESOS && deployMode == CLIENT
if (!isMesosCluster && !isStandAloneCluster) {
// Resolve maven dependencies if there are any and add classpath to jars. Add them to py-files
@@ -337,7 +338,7 @@ private[spark] class SparkSubmit extends Logging {
val targetDir = Utils.createTempDir()
// assure a keytab is available from any place in a JVM
- if (clusterManager == YARN || clusterManager == LOCAL || clusterManager == MESOS) {
+ if (clusterManager == YARN || clusterManager == LOCAL || isMesosClient) {
if (args.principal != null) {
if (args.keytab != null) {
require(new File(args.keytab).exists(), s"Keytab file: ${args.keytab} does not exist")
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
index 0733fdb72cafb..fed4e0a5069c3 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -36,7 +36,6 @@ import org.apache.spark.launcher.SparkSubmitArgumentsParser
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.util.Utils
-
/**
* Parses and encapsulates arguments from the spark-submit script.
* The env argument is used for testing.
@@ -76,6 +75,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
var proxyUser: String = null
var principal: String = null
var keytab: String = null
+ private var dynamicAllocationEnabled: Boolean = false
// Standalone cluster mode only
var supervise: Boolean = false
@@ -198,6 +198,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
queue = Option(queue).orElse(sparkProperties.get("spark.yarn.queue")).orNull
keytab = Option(keytab).orElse(sparkProperties.get("spark.yarn.keytab")).orNull
principal = Option(principal).orElse(sparkProperties.get("spark.yarn.principal")).orNull
+ dynamicAllocationEnabled =
+ sparkProperties.get("spark.dynamicAllocation.enabled").exists("true".equalsIgnoreCase)
// Try to set main class from JAR if no --class argument is given
if (mainClass == null && !isPython && !isR && primaryResource != null) {
@@ -274,7 +276,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
if (totalExecutorCores != null && Try(totalExecutorCores.toInt).getOrElse(-1) <= 0) {
error("Total executor cores must be a positive number")
}
- if (numExecutors != null && Try(numExecutors.toInt).getOrElse(-1) <= 0) {
+ if (!dynamicAllocationEnabled &&
+ numExecutors != null && Try(numExecutors.toInt).getOrElse(-1) <= 0) {
error("Number of executors must be a positive number")
}
if (pyFiles != null && !isPython) {
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
index 6fc12d721e6f1..32667ddf5c7ea 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
@@ -37,8 +37,8 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("")
val lastUpdatedTime = parent.getLastUpdatedTime()
val providerConfig = parent.getProviderConfig()
val content =
-      <script src={UIUtils.prependBaseUri("/static/historypage-common.js")}></script> ++
-      <script src={UIUtils.prependBaseUri("/static/utils.js")}></script>
+      <script src={UIUtils.prependBaseUri(request, "/static/historypage-common.js")}></script> ++
+      <script src={UIUtils.prependBaseUri(request, "/static/utils.js")}></script>
val msg = <div class="row-fluid">Application {appId} not found.</div>
res.setStatus(HttpServletResponse.SC_NOT_FOUND)
- UIUtils.basicSparkPage(msg, "Not Found").foreach { n =>
+ UIUtils.basicSparkPage(req, msg, "Not Found").foreach { n =>
res.getWriter().write(n.toString)
}
return
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
index f699c75085fe1..fad4e46dc035d 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
@@ -40,7 +40,7 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app")
.getOrElse(state.completedApps.find(_.id == appId).orNull)
if (app == null) {
val msg = <div class="row-fluid">No running application with ID {appId}</div>
@@ -1814,6 +1825,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see
- In version 2.3 and earlier, `to_utc_timestamp` and `from_utc_timestamp` respect the timezone in the input timestamp string, which breaks the assumption that the input timestamp is in a specific timezone. Therefore, these 2 functions can return unexpected results. In version 2.4 and later, this problem has been fixed. `to_utc_timestamp` and `from_utc_timestamp` will return null if the input timestamp string contains timezone. As an example, `from_utc_timestamp('2000-10-10 00:00:00', 'GMT+1')` will return `2000-10-10 01:00:00` in both Spark 2.3 and 2.4. However, `from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')`, assuming a local timezone of GMT+8, will return `2000-10-10 09:00:00` in Spark 2.3 but `null` in 2.4. For people who don't care about this problem and want to retain the previous behaivor to keep their query unchanged, you can set `spark.sql.function.rejectTimezoneInString` to false. This option will be removed in Spark 3.0 and should only be used as a temporary workaround.
- In version 2.3 and earlier, Spark converts Parquet Hive tables by default but ignores table properties like `TBLPROPERTIES (parquet.compression 'NONE')`. This happens for ORC Hive table properties like `TBLPROPERTIES (orc.compress 'NONE')` in case of `spark.sql.hive.convertMetastoreOrc=true`, too. Since Spark 2.4, Spark respects Parquet/ORC specific table properties while converting Parquet/ORC Hive tables. As an example, `CREATE TABLE t(id int) STORED AS PARQUET TBLPROPERTIES (parquet.compression 'NONE')` would generate Snappy parquet files during insertion in Spark 2.3, and in Spark 2.4, the result would be uncompressed parquet files.
- Since Spark 2.0, Spark converts Parquet Hive tables by default for better performance. Since Spark 2.4, Spark converts ORC Hive tables by default, too. It means Spark uses its own ORC support by default instead of Hive SerDe. As an example, `CREATE TABLE t(id int) STORED AS ORC` would be handled with Hive SerDe in Spark 2.3, and in Spark 2.4, it would be converted into Spark's ORC data source table and ORC vectorization would be applied. To set `false` to `spark.sql.hive.convertMetastoreOrc` restores the previous behavior.
+ - In version 2.3 and earlier, a CSV row is considered malformed if at least one column value in the row is malformed. The CSV parser drops such rows in the DROPMALFORMED mode and reports an error in the FAILFAST mode. Since Spark 2.4, a CSV row is considered malformed only when it contains malformed values in the columns requested from the CSV datasource; other values can be ignored. As an example, take a CSV file with an "id,name" header and one row "1234". In Spark 2.4, selecting only the id column returns a row with the single value 1234, but in Spark 2.3 and earlier that row is empty in the DROPMALFORMED mode. To restore the previous behavior, set `spark.sql.csv.parser.columnPruning.enabled` to `false`.
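A sketch of the behavior change described in the note above (file path and contents are made up; this assumes a local Spark 2.4 session):

```scala
// Sketch of the Spark 2.4 CSV column-pruning behavior.
import org.apache.spark.sql.SparkSession

object CsvColumnPruningExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("csv-pruning").getOrCreate()

    // Suppose /tmp/people.csv contains:
    //   id,name
    //   1234
    val df = spark.read
      .option("header", "true")
      .option("mode", "DROPMALFORMED")
      .csv("/tmp/people.csv")

    // Spark 2.4: the row survives because only `id` is requested and it parses fine.
    // Spark 2.3 and earlier: the whole row is dropped because `name` is missing.
    df.select("id").show()

    // Restore the old behavior in 2.4 before re-reading, if needed:
    spark.conf.set("spark.sql.csv.parser.columnPruning.enabled", "false")
  }
}
```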
## Upgrading From Spark SQL 2.2 to 2.3
@@ -2237,7 +2249,7 @@ referencing a singleton.
Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs.
Currently, Hive SerDes and UDFs are based on Hive 1.2.1,
and Spark SQL can be connected to different versions of Hive Metastore
-(from 0.12.0 to 2.3.2. Also see [Interacting with Different Versions of Hive Metastore](#interacting-with-different-versions-of-hive-metastore)).
+(from 0.12.0 to 2.3.3. Also see [Interacting with Different Versions of Hive Metastore](#interacting-with-different-versions-of-hive-metastore)).
#### Deploying in Existing Hive Warehouses
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala
index 88abf8a8dd027..badaa69cc303c 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala
@@ -106,7 +106,7 @@ class KafkaContinuousReader(
startOffsets.toSeq.map {
case (topicPartition, start) =>
- KafkaContinuousDataReaderFactory(
+ KafkaContinuousInputPartition(
topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss)
.asInstanceOf[InputPartition[UnsafeRow]]
}.asJava
@@ -146,7 +146,7 @@ class KafkaContinuousReader(
}
/**
- * A data reader factory for continuous Kafka processing. This will be serialized and transformed
+ * An input partition for continuous Kafka processing. This will be serialized and transformed
* into a full reader on executors.
*
* @param topicPartition The (topic, partition) pair this task is responsible for.
@@ -156,7 +156,7 @@ class KafkaContinuousReader(
* @param failOnDataLoss Flag indicating whether data reader should fail if some offsets
* are skipped.
*/
-case class KafkaContinuousDataReaderFactory(
+case class KafkaContinuousInputPartition(
topicPartition: TopicPartition,
startOffset: Long,
kafkaParams: ju.Map[String, Object],
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala
index 48508d057a540..941f0ab177e48 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala
@@ -395,7 +395,7 @@ private[kafka010] object KafkaDataConsumer extends Logging {
// likely running on a beefy machine that can handle a large number of simultaneously
// active consumers.
- if (entry.getValue.inUse == false && this.size > capacity) {
+ if (!entry.getValue.inUse && this.size > capacity) {
logWarning(
s"KafkaConsumer cache hitting max capacity of $capacity, " +
s"removing consumer for ${entry.getKey}")
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala
index 8a377738ea782..64ba98762788c 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala
@@ -143,7 +143,7 @@ private[kafka010] class KafkaMicroBatchReader(
// Generate factories based on the offset ranges
val factories = offsetRanges.map { range =>
- new KafkaMicroBatchDataReaderFactory(
+ new KafkaMicroBatchInputPartition(
range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer)
}
factories.map(_.asInstanceOf[InputPartition[UnsafeRow]]).asJava
@@ -300,7 +300,7 @@ private[kafka010] class KafkaMicroBatchReader(
}
/** A [[InputPartition]] for reading Kafka data in a micro-batch streaming query. */
-private[kafka010] case class KafkaMicroBatchDataReaderFactory(
+private[kafka010] case class KafkaMicroBatchInputPartition(
offsetRange: KafkaOffsetRange,
executorKafkaParams: ju.Map[String, Object],
pollTimeoutMs: Long,
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
index 871f9700cd1db..c6412eac97dba 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
@@ -679,7 +679,7 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase {
Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 100L)))
)
val factories = reader.planUnsafeInputPartitions().asScala
- .map(_.asInstanceOf[KafkaMicroBatchDataReaderFactory])
+ .map(_.asInstanceOf[KafkaMicroBatchInputPartition])
withClue(s"minPartitions = $minPartitions generated factories $factories\n\t") {
assert(factories.size == numPartitionsGenerated)
factories.foreach { f => assert(f.reuseKafkaConsumer == reusesConsumers) }
diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala
deleted file mode 100644
index aeb8c1dc342b3..0000000000000
--- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala
+++ /dev/null
@@ -1,226 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.streaming.kafka010
-
-import java.{ util => ju }
-
-import org.apache.kafka.clients.consumer.{ ConsumerConfig, ConsumerRecord, KafkaConsumer }
-import org.apache.kafka.common.{ KafkaException, TopicPartition }
-
-import org.apache.spark.internal.Logging
-
-/**
- * Consumer of single topicpartition, intended for cached reuse.
- * Underlying consumer is not threadsafe, so neither is this,
- * but processing the same topicpartition and group id in multiple threads is usually bad anyway.
- */
-private[kafka010]
-class CachedKafkaConsumer[K, V] private(
- val groupId: String,
- val topic: String,
- val partition: Int,
- val kafkaParams: ju.Map[String, Object]) extends Logging {
-
- require(groupId == kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG),
- "groupId used for cache key must match the groupId in kafkaParams")
-
- val topicPartition = new TopicPartition(topic, partition)
-
- protected val consumer = {
- val c = new KafkaConsumer[K, V](kafkaParams)
- val tps = new ju.ArrayList[TopicPartition]()
- tps.add(topicPartition)
- c.assign(tps)
- c
- }
-
- // TODO if the buffer was kept around as a random-access structure,
- // could possibly optimize re-calculating of an RDD in the same batch
- protected var buffer = ju.Collections.emptyListIterator[ConsumerRecord[K, V]]()
- protected var nextOffset = -2L
-
- def close(): Unit = consumer.close()
-
- /**
- * Get the record for the given offset, waiting up to timeout ms if IO is necessary.
- * Sequential forward access will use buffers, but random access will be horribly inefficient.
- */
- def get(offset: Long, timeout: Long): ConsumerRecord[K, V] = {
- logDebug(s"Get $groupId $topic $partition nextOffset $nextOffset requested $offset")
- if (offset != nextOffset) {
- logInfo(s"Initial fetch for $groupId $topic $partition $offset")
- seek(offset)
- poll(timeout)
- }
-
- if (!buffer.hasNext()) { poll(timeout) }
- require(buffer.hasNext(),
- s"Failed to get records for $groupId $topic $partition $offset after polling for $timeout")
- var record = buffer.next()
-
- if (record.offset != offset) {
- logInfo(s"Buffer miss for $groupId $topic $partition $offset")
- seek(offset)
- poll(timeout)
- require(buffer.hasNext(),
- s"Failed to get records for $groupId $topic $partition $offset after polling for $timeout")
- record = buffer.next()
- require(record.offset == offset,
- s"Got wrong record for $groupId $topic $partition even after seeking to offset $offset " +
- s"got offset ${record.offset} instead. If this is a compacted topic, consider enabling " +
- "spark.streaming.kafka.allowNonConsecutiveOffsets"
- )
- }
-
- nextOffset = offset + 1
- record
- }
-
- /**
- * Start a batch on a compacted topic
- */
- def compactedStart(offset: Long, timeout: Long): Unit = {
- logDebug(s"compacted start $groupId $topic $partition starting $offset")
- // This seek may not be necessary, but it's hard to tell due to gaps in compacted topics
- if (offset != nextOffset) {
- logInfo(s"Initial fetch for compacted $groupId $topic $partition $offset")
- seek(offset)
- poll(timeout)
- }
- }
-
- /**
- * Get the next record in the batch from a compacted topic.
- * Assumes compactedStart has been called first, and ignores gaps.
- */
- def compactedNext(timeout: Long): ConsumerRecord[K, V] = {
- if (!buffer.hasNext()) {
- poll(timeout)
- }
- require(buffer.hasNext(),
- s"Failed to get records for compacted $groupId $topic $partition after polling for $timeout")
- val record = buffer.next()
- nextOffset = record.offset + 1
- record
- }
-
- /**
- * Rewind to previous record in the batch from a compacted topic.
- * @throws NoSuchElementException if no previous element
- */
- def compactedPrevious(): ConsumerRecord[K, V] = {
- buffer.previous()
- }
-
- private def seek(offset: Long): Unit = {
- logDebug(s"Seeking to $topicPartition $offset")
- consumer.seek(topicPartition, offset)
- }
-
- private def poll(timeout: Long): Unit = {
- val p = consumer.poll(timeout)
- val r = p.records(topicPartition)
- logDebug(s"Polled ${p.partitions()} ${r.size}")
- buffer = r.listIterator
- }
-
-}
-
-private[kafka010]
-object CachedKafkaConsumer extends Logging {
-
- private case class CacheKey(groupId: String, topic: String, partition: Int)
-
- // Don't want to depend on guava, don't want a cleanup thread, use a simple LinkedHashMap
- private var cache: ju.LinkedHashMap[CacheKey, CachedKafkaConsumer[_, _]] = null
-
- /** Must be called before get, once per JVM, to configure the cache. Further calls are ignored */
- def init(
- initialCapacity: Int,
- maxCapacity: Int,
- loadFactor: Float): Unit = CachedKafkaConsumer.synchronized {
- if (null == cache) {
- logInfo(s"Initializing cache $initialCapacity $maxCapacity $loadFactor")
- cache = new ju.LinkedHashMap[CacheKey, CachedKafkaConsumer[_, _]](
- initialCapacity, loadFactor, true) {
- override def removeEldestEntry(
- entry: ju.Map.Entry[CacheKey, CachedKafkaConsumer[_, _]]): Boolean = {
- if (this.size > maxCapacity) {
- try {
- entry.getValue.consumer.close()
- } catch {
- case x: KafkaException =>
- logError("Error closing oldest Kafka consumer", x)
- }
- true
- } else {
- false
- }
- }
- }
- }
- }
-
- /**
- * Get a cached consumer for groupId, assigned to topic and partition.
- * If matching consumer doesn't already exist, will be created using kafkaParams.
- */
- def get[K, V](
- groupId: String,
- topic: String,
- partition: Int,
- kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer[K, V] =
- CachedKafkaConsumer.synchronized {
- val k = CacheKey(groupId, topic, partition)
- val v = cache.get(k)
- if (null == v) {
- logInfo(s"Cache miss for $k")
- logDebug(cache.keySet.toString)
- val c = new CachedKafkaConsumer[K, V](groupId, topic, partition, kafkaParams)
- cache.put(k, c)
- c
- } else {
- // any given topicpartition should have a consistent key and value type
- v.asInstanceOf[CachedKafkaConsumer[K, V]]
- }
- }
-
- /**
- * Get a fresh new instance, unassociated with the global cache.
- * Caller is responsible for closing
- */
- def getUncached[K, V](
- groupId: String,
- topic: String,
- partition: Int,
- kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer[K, V] =
- new CachedKafkaConsumer[K, V](groupId, topic, partition, kafkaParams)
-
- /** remove consumer for given groupId, topic, and partition, if it exists */
- def remove(groupId: String, topic: String, partition: Int): Unit = {
- val k = CacheKey(groupId, topic, partition)
- logInfo(s"Removing $k from cache")
- val v = CachedKafkaConsumer.synchronized {
- cache.remove(k)
- }
- if (null != v) {
- v.close()
- logInfo(s"Removed $k from cache")
- }
- }
-}
diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumer.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumer.scala
new file mode 100644
index 0000000000000..68c5fe9ab066a
--- /dev/null
+++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumer.scala
@@ -0,0 +1,359 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kafka010
+
+import java.{util => ju}
+
+import org.apache.kafka.clients.consumer.{ConsumerConfig, ConsumerRecord, KafkaConsumer}
+import org.apache.kafka.common.{KafkaException, TopicPartition}
+
+import org.apache.spark.TaskContext
+import org.apache.spark.internal.Logging
+
+private[kafka010] sealed trait KafkaDataConsumer[K, V] {
+ /**
+ * Get the record for the given offset if available.
+ *
+ * @param offset the offset to fetch.
+ * @param pollTimeoutMs timeout in milliseconds to poll data from Kafka.
+ */
+ def get(offset: Long, pollTimeoutMs: Long): ConsumerRecord[K, V] = {
+ internalConsumer.get(offset, pollTimeoutMs)
+ }
+
+ /**
+ * Start a batch on a compacted topic
+ *
+ * @param offset the offset to fetch.
+ * @param pollTimeoutMs timeout in milliseconds to poll data from Kafka.
+ */
+ def compactedStart(offset: Long, pollTimeoutMs: Long): Unit = {
+ internalConsumer.compactedStart(offset, pollTimeoutMs)
+ }
+
+ /**
+ * Get the next record in the batch from a compacted topic.
+ * Assumes compactedStart has been called first, and ignores gaps.
+ *
+ * @param pollTimeoutMs timeout in milliseconds to poll data from Kafka.
+ */
+ def compactedNext(pollTimeoutMs: Long): ConsumerRecord[K, V] = {
+ internalConsumer.compactedNext(pollTimeoutMs)
+ }
+
+ /**
+ * Rewind to previous record in the batch from a compacted topic.
+ *
+ * @throws NoSuchElementException if no previous element
+ */
+ def compactedPrevious(): ConsumerRecord[K, V] = {
+ internalConsumer.compactedPrevious()
+ }
+
+ /**
+ * Release this consumer from being further used. Depending on its implementation,
+ * this consumer will be either finalized, or reset for reuse later.
+ */
+ def release(): Unit
+
+ /** Reference to the internal implementation that this wrapper delegates to */
+ def internalConsumer: InternalKafkaConsumer[K, V]
+}
+
+
+/**
+ * A wrapper around Kafka's KafkaConsumer.
+ * This is not for direct use outside this file.
+ */
+private[kafka010] class InternalKafkaConsumer[K, V](
+ val topicPartition: TopicPartition,
+ val kafkaParams: ju.Map[String, Object]) extends Logging {
+
+ private[kafka010] val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG)
+ .asInstanceOf[String]
+
+ private val consumer = createConsumer
+
+ /** indicates whether this consumer is in use or not */
+ var inUse = true
+
+ /** indicates whether this consumer is going to be stopped in the next release */
+ var markedForClose = false
+
+ // TODO if the buffer was kept around as a random-access structure,
+ // could possibly optimize re-calculating of an RDD in the same batch
+ @volatile private var buffer = ju.Collections.emptyListIterator[ConsumerRecord[K, V]]()
+ @volatile private var nextOffset = InternalKafkaConsumer.UNKNOWN_OFFSET
+
+ override def toString: String = {
+ "InternalKafkaConsumer(" +
+ s"hash=${Integer.toHexString(hashCode)}, " +
+ s"groupId=$groupId, " +
+ s"topicPartition=$topicPartition)"
+ }
+
+ /** Create a KafkaConsumer to fetch records for `topicPartition` */
+ private def createConsumer: KafkaConsumer[K, V] = {
+ val c = new KafkaConsumer[K, V](kafkaParams)
+ val topics = ju.Arrays.asList(topicPartition)
+ c.assign(topics)
+ c
+ }
+
+ def close(): Unit = consumer.close()
+
+ /**
+ * Get the record for the given offset, waiting up to timeout ms if IO is necessary.
+ * Sequential forward access will use buffers, but random access will be horribly inefficient.
+ */
+ def get(offset: Long, timeout: Long): ConsumerRecord[K, V] = {
+ logDebug(s"Get $groupId $topicPartition nextOffset $nextOffset requested $offset")
+ if (offset != nextOffset) {
+ logInfo(s"Initial fetch for $groupId $topicPartition $offset")
+ seek(offset)
+ poll(timeout)
+ }
+
+ if (!buffer.hasNext()) {
+ poll(timeout)
+ }
+ require(buffer.hasNext(),
+ s"Failed to get records for $groupId $topicPartition $offset after polling for $timeout")
+ var record = buffer.next()
+
+ if (record.offset != offset) {
+ logInfo(s"Buffer miss for $groupId $topicPartition $offset")
+ seek(offset)
+ poll(timeout)
+ require(buffer.hasNext(),
+ s"Failed to get records for $groupId $topicPartition $offset after polling for $timeout")
+ record = buffer.next()
+ require(record.offset == offset,
+ s"Got wrong record for $groupId $topicPartition even after seeking to offset $offset " +
+ s"got offset ${record.offset} instead. If this is a compacted topic, consider enabling " +
+ "spark.streaming.kafka.allowNonConsecutiveOffsets"
+ )
+ }
+
+ nextOffset = offset + 1
+ record
+ }
+
+ /**
+ * Start a batch on a compacted topic
+ */
+ def compactedStart(offset: Long, pollTimeoutMs: Long): Unit = {
+ logDebug(s"compacted start $groupId $topicPartition starting $offset")
+ // This seek may not be necessary, but it's hard to tell due to gaps in compacted topics
+ if (offset != nextOffset) {
+ logInfo(s"Initial fetch for compacted $groupId $topicPartition $offset")
+ seek(offset)
+ poll(pollTimeoutMs)
+ }
+ }
+
+ /**
+ * Get the next record in the batch from a compacted topic.
+ * Assumes compactedStart has been called first, and ignores gaps.
+ */
+ def compactedNext(pollTimeoutMs: Long): ConsumerRecord[K, V] = {
+ if (!buffer.hasNext()) {
+ poll(pollTimeoutMs)
+ }
+ require(buffer.hasNext(),
+ s"Failed to get records for compacted $groupId $topicPartition " +
+ s"after polling for $pollTimeoutMs")
+ val record = buffer.next()
+ nextOffset = record.offset + 1
+ record
+ }
+
+ /**
+ * Rewind to previous record in the batch from a compacted topic.
+ * @throws NoSuchElementException if no previous element
+ */
+ def compactedPrevious(): ConsumerRecord[K, V] = {
+ buffer.previous()
+ }
+
+ private def seek(offset: Long): Unit = {
+ logDebug(s"Seeking to $topicPartition $offset")
+ consumer.seek(topicPartition, offset)
+ }
+
+ private def poll(timeout: Long): Unit = {
+ val p = consumer.poll(timeout)
+ val r = p.records(topicPartition)
+ logDebug(s"Polled ${p.partitions()} ${r.size}")
+ buffer = r.listIterator
+ }
+
+}
+
+private[kafka010] case class CacheKey(groupId: String, topicPartition: TopicPartition)
+
+private[kafka010] object KafkaDataConsumer extends Logging {
+
+ private case class CachedKafkaDataConsumer[K, V](internalConsumer: InternalKafkaConsumer[K, V])
+ extends KafkaDataConsumer[K, V] {
+ assert(internalConsumer.inUse)
+ override def release(): Unit = KafkaDataConsumer.release(internalConsumer)
+ }
+
+ private case class NonCachedKafkaDataConsumer[K, V](internalConsumer: InternalKafkaConsumer[K, V])
+ extends KafkaDataConsumer[K, V] {
+ override def release(): Unit = internalConsumer.close()
+ }
+
+ // Don't want to depend on guava, don't want a cleanup thread, use a simple LinkedHashMap
+ private[kafka010] var cache: ju.Map[CacheKey, InternalKafkaConsumer[_, _]] = null
+
+ /**
+ * Must be called before acquire, once per JVM, to configure the cache.
+ * Further calls are ignored.
+ */
+ def init(
+ initialCapacity: Int,
+ maxCapacity: Int,
+ loadFactor: Float): Unit = synchronized {
+ if (null == cache) {
+ logInfo(s"Initializing cache $initialCapacity $maxCapacity $loadFactor")
+ cache = new ju.LinkedHashMap[CacheKey, InternalKafkaConsumer[_, _]](
+ initialCapacity, loadFactor, true) {
+ override def removeEldestEntry(
+ entry: ju.Map.Entry[CacheKey, InternalKafkaConsumer[_, _]]): Boolean = {
+
+ // Try to remove the least-used entry if it's currently not in use.
+ //
+ // If it cannot be removed, then the cache will keep growing. In the worst case, the cache
+ // will grow to the max number of concurrent tasks that can run in the executor (that is,
+ // the number of task slots), after which it will never shrink. This is unlikely to be a
+ // serious problem because an executor with more than 64 (the default) task slots is likely
+ // running on a beefy machine that can handle a large number of simultaneously active
+ // consumers.
+
+ if (!entry.getValue.inUse && this.size > maxCapacity) {
+ logWarning(
+ s"KafkaConsumer cache hitting max capacity of $maxCapacity, " +
+ s"removing consumer for ${entry.getKey}")
+ try {
+ entry.getValue.close()
+ } catch {
+ case x: KafkaException =>
+ logError("Error closing oldest Kafka consumer", x)
+ }
+ true
+ } else {
+ false
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * Get a cached consumer for groupId, assigned to topic and partition.
+ * If matching consumer doesn't already exist, will be created using kafkaParams.
+ * The returned consumer must be released explicitly using [[KafkaDataConsumer.release()]].
+ *
+ * Note: This method guarantees that the consumer returned is not currently in use by anyone
+ * else. Within this guarantee, this method will make a best effort attempt to re-use consumers by
+ * caching them and tracking when they are in use.
+ */
+ def acquire[K, V](
+ topicPartition: TopicPartition,
+ kafkaParams: ju.Map[String, Object],
+ context: TaskContext,
+ useCache: Boolean): KafkaDataConsumer[K, V] = synchronized {
+ val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String]
+ val key = new CacheKey(groupId, topicPartition)
+ val existingInternalConsumer = cache.get(key)
+
+ lazy val newInternalConsumer = new InternalKafkaConsumer[K, V](topicPartition, kafkaParams)
+
+ if (context != null && context.attemptNumber >= 1) {
+ // If this is a reattempt of the task, then invalidate the cached consumer if any and
+ // start with a new one. If the prior attempt's failures were cache related, this gets
+ // rid of the old, problematic consumer.
+ logDebug(s"Reattempt detected, invalidating cached consumer $existingInternalConsumer")
+ if (existingInternalConsumer != null) {
+ // Consumer exists in cache. If it's in use, mark it for closing later, or close it now.
+ if (existingInternalConsumer.inUse) {
+ existingInternalConsumer.markedForClose = true
+ } else {
+ existingInternalConsumer.close()
+ // Remove the consumer from the cache only if it's closed.
+ // Consumers marked for close will be removed in the release() function.
+ cache.remove(key)
+ }
+ }
+
+ logDebug("Reattempt detected, new non-cached consumer will be allocated " +
+ s"$newInternalConsumer")
+ NonCachedKafkaDataConsumer(newInternalConsumer)
+ } else if (!useCache) {
+ // If consumer reuse is turned off, do not use the cache; return a new consumer
+ logDebug("Cache usage turned off, new non-cached consumer will be allocated " +
+ s"$newInternalConsumer")
+ NonCachedKafkaDataConsumer(newInternalConsumer)
+ } else if (existingInternalConsumer == null) {
+ // If the consumer is not already cached, then put a new one in the cache and return it
+ logDebug("No cached consumer, new cached consumer will be allocated " +
+ s"$newInternalConsumer")
+ cache.put(key, newInternalConsumer)
+ CachedKafkaDataConsumer(newInternalConsumer)
+ } else if (existingInternalConsumer.inUse) {
+ // If consumer is already cached but is currently in use, then return a new consumer
+ logDebug("Used cached consumer found, new non-cached consumer will be allocated " +
+ s"$newInternalConsumer")
+ NonCachedKafkaDataConsumer(newInternalConsumer)
+ } else {
+ // If consumer is already cached and is currently not in use, then return that consumer
+ logDebug(s"Not used cached consumer found, re-using it $existingInternalConsumer")
+ existingInternalConsumer.inUse = true
+ // Any given TopicPartition should have a consistent key and value type
+ CachedKafkaDataConsumer(existingInternalConsumer.asInstanceOf[InternalKafkaConsumer[K, V]])
+ }
+ }
+
+ private def release(internalConsumer: InternalKafkaConsumer[_, _]): Unit = synchronized {
+ // Clear the consumer from the cache if this is indeed the consumer present in the cache
+ val key = new CacheKey(internalConsumer.groupId, internalConsumer.topicPartition)
+ val cachedInternalConsumer = cache.get(key)
+ if (internalConsumer.eq(cachedInternalConsumer)) {
+ // The released consumer is the same object as the cached one.
+ if (internalConsumer.markedForClose) {
+ internalConsumer.close()
+ cache.remove(key)
+ } else {
+ internalConsumer.inUse = false
+ }
+ } else {
+ // The released consumer is either not the same one as in the cache, or not in the cache
+ // at all. This may happen if the cache was invalidated while this consumer was being used.
+ // Just close this consumer.
+ internalConsumer.close()
+ logInfo(s"Released a supposedly cached consumer that was not found in the cache " +
+ s"$internalConsumer")
+ }
+ }
+}
+
+private[kafka010] object InternalKafkaConsumer {
+ private val UNKNOWN_OFFSET = -2L
+}
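For context, the acquire/release contract of the new `KafkaDataConsumer` is what `KafkaRDDIterator` follows in the next file. A minimal sketch of that pattern (this is package-internal API, and the parameter values are illustrative):

```scala
// Sketch: the init/acquire/get/release pattern expected by the new KafkaDataConsumer.
package org.apache.spark.streaming.kafka010

import java.{util => ju}

import org.apache.kafka.common.TopicPartition

import org.apache.spark.TaskContext

object KafkaDataConsumerUsage {
  def readOne(
      tp: TopicPartition,
      kafkaParams: ju.Map[String, Object],
      offset: Long,
      context: TaskContext): String = {
    // Configure the cache once per JVM; repeated calls are no-ops.
    KafkaDataConsumer.init(initialCapacity = 16, maxCapacity = 64, loadFactor = 0.75f)
    val consumer = KafkaDataConsumer.acquire[Array[Byte], Array[Byte]](
      tp, kafkaParams, context, useCache = true)
    try {
      new String(consumer.get(offset, 10000).value())
    } finally {
      // Always release: cached consumers go back to the pool, non-cached ones are closed.
      consumer.release()
    }
  }
}
```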
diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala
index 07239eda64d2e..81abc9860bfc3 100644
--- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala
+++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala
@@ -19,8 +19,6 @@ package org.apache.spark.streaming.kafka010
import java.{ util => ju }
-import scala.collection.mutable.ArrayBuffer
-
import org.apache.kafka.clients.consumer.{ ConsumerConfig, ConsumerRecord }
import org.apache.kafka.common.TopicPartition
@@ -239,26 +237,18 @@ private class KafkaRDDIterator[K, V](
cacheLoadFactor: Float
) extends Iterator[ConsumerRecord[K, V]] {
- val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String]
-
context.addTaskCompletionListener(_ => closeIfNeeded())
- val consumer = if (useConsumerCache) {
- CachedKafkaConsumer.init(cacheInitialCapacity, cacheMaxCapacity, cacheLoadFactor)
- if (context.attemptNumber >= 1) {
- // just in case the prior attempt failures were cache related
- CachedKafkaConsumer.remove(groupId, part.topic, part.partition)
- }
- CachedKafkaConsumer.get[K, V](groupId, part.topic, part.partition, kafkaParams)
- } else {
- CachedKafkaConsumer.getUncached[K, V](groupId, part.topic, part.partition, kafkaParams)
+ val consumer = {
+ KafkaDataConsumer.init(cacheInitialCapacity, cacheMaxCapacity, cacheLoadFactor)
+ KafkaDataConsumer.acquire[K, V](part.topicPartition(), kafkaParams, context, useConsumerCache)
}
var requestOffset = part.fromOffset
def closeIfNeeded(): Unit = {
- if (!useConsumerCache && consumer != null) {
- consumer.close()
+ if (consumer != null) {
+ consumer.release()
}
}
diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala
new file mode 100644
index 0000000000000..d934c64962adb
--- /dev/null
+++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kafka010
+
+import java.util.concurrent.{Executors, TimeUnit}
+
+import scala.collection.JavaConverters._
+import scala.util.Random
+
+import org.apache.kafka.clients.consumer.ConsumerConfig._
+import org.apache.kafka.common.TopicPartition
+import org.apache.kafka.common.serialization.ByteArrayDeserializer
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark._
+
+class KafkaDataConsumerSuite extends SparkFunSuite with BeforeAndAfterAll {
+ private var testUtils: KafkaTestUtils = _
+ private val topic = "topic" + Random.nextInt()
+ private val topicPartition = new TopicPartition(topic, 0)
+ private val groupId = "groupId"
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ testUtils = new KafkaTestUtils
+ testUtils.setup()
+ KafkaDataConsumer.init(16, 64, 0.75f)
+ }
+
+ override def afterAll(): Unit = {
+ if (testUtils != null) {
+ testUtils.teardown()
+ testUtils = null
+ }
+ super.afterAll()
+ }
+
+ private def getKafkaParams() = Map[String, Object](
+ GROUP_ID_CONFIG -> groupId,
+ BOOTSTRAP_SERVERS_CONFIG -> testUtils.brokerAddress,
+ KEY_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName,
+ VALUE_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName,
+ AUTO_OFFSET_RESET_CONFIG -> "earliest",
+ ENABLE_AUTO_COMMIT_CONFIG -> "false"
+ ).asJava
+
+ test("KafkaDataConsumer reuse in case of same groupId and TopicPartition") {
+ KafkaDataConsumer.cache.clear()
+
+ val kafkaParams = getKafkaParams()
+
+ val consumer1 = KafkaDataConsumer.acquire[Array[Byte], Array[Byte]](
+ topicPartition, kafkaParams, null, true)
+ consumer1.release()
+
+ val consumer2 = KafkaDataConsumer.acquire[Array[Byte], Array[Byte]](
+ topicPartition, kafkaParams, null, true)
+ consumer2.release()
+
+ assert(KafkaDataConsumer.cache.size() == 1)
+ val key = new CacheKey(groupId, topicPartition)
+ val existingInternalConsumer = KafkaDataConsumer.cache.get(key)
+ assert(existingInternalConsumer.eq(consumer1.internalConsumer))
+ assert(existingInternalConsumer.eq(consumer2.internalConsumer))
+ }
+
+ test("concurrent use of KafkaDataConsumer") {
+ val data = (1 to 1000).map(_.toString)
+ testUtils.createTopic(topic)
+ testUtils.sendMessages(topic, data.toArray)
+
+ val kafkaParams = getKafkaParams()
+
+ val numThreads = 100
+ val numConsumerUsages = 500
+
+ @volatile var error: Throwable = null
+
+ def consume(i: Int): Unit = {
+ val useCache = Random.nextBoolean
+ val taskContext = if (Random.nextBoolean) {
+ new TaskContextImpl(0, 0, 0, 0, attemptNumber = Random.nextInt(2), null, null, null)
+ } else {
+ null
+ }
+ val consumer = KafkaDataConsumer.acquire[Array[Byte], Array[Byte]](
+ topicPartition, kafkaParams, taskContext, useCache)
+ try {
+ val rcvd = (0 until data.length).map { offset =>
+ val bytes = consumer.get(offset, 10000).value()
+ new String(bytes)
+ }
+ assert(rcvd == data)
+ } catch {
+ case e: Throwable =>
+ error = e
+ throw e
+ } finally {
+ consumer.release()
+ }
+ }
+
+ val threadPool = Executors.newFixedThreadPool(numThreads)
+ try {
+ val futures = (1 to numConsumerUsages).map { i =>
+ threadPool.submit(new Runnable {
+ override def run(): Unit = { consume(i) }
+ })
+ }
+ futures.foreach(_.get(1, TimeUnit.MINUTES))
+ assert(error == null)
+ } finally {
+ threadPool.shutdown()
+ }
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 3fb6d1e4e4f3e..337133a2e2326 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -146,12 +146,21 @@ class GBTClassifier @Since("1.4.0") (
@Since("1.4.0")
def setLossType(value: String): this.type = set(lossType, value)
+ /** @group setParam */
+ @Since("2.4.0")
+ def setValidationIndicatorCol(value: String): this.type = {
+ set(validationIndicatorCol, value)
+ }
+
override protected def train(dataset: Dataset[_]): GBTClassificationModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
+
+ val withValidation = isDefined(validationIndicatorCol) && $(validationIndicatorCol).nonEmpty
+
// We copy and modify this from Classifier.extractLabeledPoints since GBT only supports
// 2 classes now. This lets us provide a more precise error message.
- val oldDataset: RDD[LabeledPoint] =
+ val convert2LabeledPoint = (dataset: Dataset[_]) => {
dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
case Row(label: Double, features: Vector) =>
require(label == 0 || label == 1, s"GBTClassifier was given" +
@@ -159,7 +168,18 @@ class GBTClassifier @Since("1.4.0") (
s" GBTClassifier currently only supports binary classification.")
LabeledPoint(label, features)
}
- val numFeatures = oldDataset.first().features.size
+ }
+
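+    // Rows where the indicator column is false are used to fit the ensemble; rows where it
+    // is true are held out and only consulted to decide when to stop boosting early.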
+ val (trainDataset, validationDataset) = if (withValidation) {
+ (
+ convert2LabeledPoint(dataset.filter(not(col($(validationIndicatorCol))))),
+ convert2LabeledPoint(dataset.filter(col($(validationIndicatorCol))))
+ )
+ } else {
+ (convert2LabeledPoint(dataset), null)
+ }
+
+ val numFeatures = trainDataset.first().features.size
val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
val numClasses = 2
@@ -169,15 +189,21 @@ class GBTClassifier @Since("1.4.0") (
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
}
- val instr = Instrumentation.create(this, oldDataset)
+ val instr = Instrumentation.create(this, dataset)
instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType,
maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode,
- seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy)
+ seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy,
+ validationIndicatorCol)
instr.logNumFeatures(numFeatures)
instr.logNumClasses(numClasses)
- val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy,
- $(seed), $(featureSubsetStrategy))
+ val (baseLearners, learnerWeights) = if (withValidation) {
+ GradientBoostedTrees.runWithValidation(trainDataset, validationDataset, boostingStrategy,
+ $(seed), $(featureSubsetStrategy))
+ } else {
+ GradientBoostedTrees.run(trainDataset, boostingStrategy, $(seed), $(featureSubsetStrategy))
+ }
+
val m = new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures)
instr.logSuccess(m)
m
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
index 438e53ba6197c..1ad4e097246a3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -261,8 +261,9 @@ class BisectingKMeans @Since("2.0.0") (
transformSchema(dataset.schema, logging = true)
val rdd = DatasetUtils.columnToOldVector(dataset, getFeaturesCol)
- val instr = Instrumentation.create(this, rdd)
- instr.logParams(featuresCol, predictionCol, k, maxIter, seed, minDivisibleClusterSize)
+ val instr = Instrumentation.create(this, dataset)
+ instr.logParams(featuresCol, predictionCol, k, maxIter, seed,
+ minDivisibleClusterSize, distanceMeasure)
val bkm = new MLlibBisectingKMeans()
.setK($(k))
@@ -275,6 +276,8 @@ class BisectingKMeans @Since("2.0.0") (
val summary = new BisectingKMeansSummary(
model.transform(dataset), $(predictionCol), $(featuresCol), $(k))
model.setSummary(Some(summary))
+ // TODO: need to extend logNamedValue to support Array
+ instr.logNamedValue("clusterSizes", summary.clusterSizes.mkString("[", ",", "]"))
instr.logSuccess(model)
model
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index 88d618c3a03a8..3091bb5a2e54c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -352,7 +352,7 @@ class GaussianMixture @Since("2.0.0") (
s"than ${GaussianMixture.MAX_NUM_FEATURES} features because the size of the covariance" +
s" matrix is quadratic in the number of features.")
- val instr = Instrumentation.create(this, instances)
+ val instr = Instrumentation.create(this, dataset)
instr.logParams(featuresCol, predictionCol, probabilityCol, k, maxIter, seed, tol)
instr.logNumFeatures(numFeatures)
@@ -425,6 +425,9 @@ class GaussianMixture @Since("2.0.0") (
val summary = new GaussianMixtureSummary(model.transform(dataset),
$(predictionCol), $(probabilityCol), $(featuresCol), $(k), logLikelihood)
model.setSummary(Some(summary))
+ instr.logNamedValue("logLikelihood", logLikelihood)
+ // TODO: need to extend logNamedValue to support Array
+ instr.logNamedValue("clusterSizes", summary.clusterSizes.mkString("[", ",", "]"))
instr.logSuccess(model)
model
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 97f246fbfd859..e72d7f9485e6a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -342,7 +342,7 @@ class KMeans @Since("1.5.0") (
instances.persist(StorageLevel.MEMORY_AND_DISK)
}
- val instr = Instrumentation.create(this, instances)
+ val instr = Instrumentation.create(this, dataset)
instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, distanceMeasure,
maxIter, seed, tol)
val algo = new MLlibKMeans()
@@ -359,6 +359,8 @@ class KMeans @Since("1.5.0") (
model.transform(dataset), $(predictionCol), $(featuresCol), $(k))
model.setSummary(Some(summary))
+ // TODO: need to extend logNamedValue to support Array
+ instr.logNamedValue("clusterSizes", summary.clusterSizes.mkString("[", ",", "]"))
instr.logSuccess(model)
if (handlePersistence) {
instances.unpersist()
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index afe599cd167cb..fed42c959b5ef 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -569,10 +569,14 @@ abstract class LDAModel private[ml] (
class LocalLDAModel private[ml] (
uid: String,
vocabSize: Int,
- @Since("1.6.0") override private[clustering] val oldLocalModel: OldLocalLDAModel,
+ private[clustering] val oldLocalModel_ : OldLocalLDAModel,
sparkSession: SparkSession)
extends LDAModel(uid, vocabSize, sparkSession) {
+ override private[clustering] def oldLocalModel: OldLocalLDAModel = {
+ oldLocalModel_.setSeed(getSeed)
+ }
+
@Since("1.6.0")
override def copy(extra: ParamMap): LocalLDAModel = {
val copied = new LocalLDAModel(uid, vocabSize, oldLocalModel, sparkSession)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
index 0bf405d9abf9d..d7fbe28ae7a64 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
@@ -161,6 +161,8 @@ class FPGrowth @Since("2.2.0") (
private def genericFit[T: ClassTag](dataset: Dataset[_]): FPGrowthModel = {
val handlePersistence = dataset.storageLevel == StorageLevel.NONE
+ val instr = Instrumentation.create(this, dataset)
+ instr.logParams(params: _*)
val data = dataset.select($(itemsCol))
val items = data.where(col($(itemsCol)).isNotNull).rdd.map(r => r.getSeq[Any](0).toArray)
val mllibFP = new MLlibFPGrowth().setMinSupport($(minSupport))
@@ -183,7 +185,9 @@ class FPGrowth @Since("2.2.0") (
items.unpersist()
}
- copyValues(new FPGrowthModel(uid, frequentItems)).setParent(this)
+ val model = copyValues(new FPGrowthModel(uid, frequentItems)).setParent(this)
+ instr.logSuccess(model)
+ model
}
@Since("2.2.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index b9c3170cc3c28..7e08675f834da 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -95,7 +95,10 @@ private[shared] object SharedParamsCodeGen {
ParamDesc[String]("distanceMeasure", "The distance measure. Supported options: 'euclidean'" +
" and 'cosine'", Some("org.apache.spark.mllib.clustering.DistanceMeasure.EUCLIDEAN"),
isValid = "(value: String) => " +
- "org.apache.spark.mllib.clustering.DistanceMeasure.validateDistanceMeasure(value)")
+ "org.apache.spark.mllib.clustering.DistanceMeasure.validateDistanceMeasure(value)"),
+ ParamDesc[String]("validationIndicatorCol", "name of the column that indicates whether " +
+ "each row is for training or for validation. False indicates training; true indicates " +
+ "validation.")
)
val code = genSharedParams(params)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index 282ea6ebcbf7f..5928a0749f738 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -523,4 +523,21 @@ trait HasDistanceMeasure extends Params {
/** @group getParam */
final def getDistanceMeasure: String = $(distanceMeasure)
}
+
+/**
+ * Trait for shared param validationIndicatorCol. This trait may be changed or
+ * removed between minor versions.
+ */
+@DeveloperApi
+trait HasValidationIndicatorCol extends Params {
+
+ /**
+ * Param for name of the column that indicates whether each row is for training or for validation. False indicates training; true indicates validation.
+ * @group param
+ */
+ final val validationIndicatorCol: Param[String] = new Param[String](this, "validationIndicatorCol", "name of the column that indicates whether each row is for training or for validation. False indicates training; true indicates validation.")
+
+ /** @group getParam */
+ final def getValidationIndicatorCol: String = $(validationIndicatorCol)
+}
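+
+// Usage sketch (illustrative only; the DataFrame `df` and its boolean "isVal" column are
+// hypothetical): an estimator mixing in this trait, e.g.
+//   new GBTClassifier().setValidationIndicatorCol("isVal").fit(df)
+// trains on rows where "isVal" is false and uses rows where it is true for early stopping.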
// scalastyle:on
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index d7e054bf55ef6..eb8b3c001436a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -145,21 +145,42 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
override def setFeatureSubsetStrategy(value: String): this.type =
set(featureSubsetStrategy, value)
+ /** @group setParam */
+ @Since("2.4.0")
+ def setValidationIndicatorCol(value: String): this.type = {
+ set(validationIndicatorCol, value)
+ }
+
override protected def train(dataset: Dataset[_]): GBTRegressionModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
- val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
- val numFeatures = oldDataset.first().features.size
+
+ val withValidation = isDefined(validationIndicatorCol) && $(validationIndicatorCol).nonEmpty
+
+ val (trainDataset, validationDataset) = if (withValidation) {
+ (
+ extractLabeledPoints(dataset.filter(not(col($(validationIndicatorCol))))),
+ extractLabeledPoints(dataset.filter(col($(validationIndicatorCol))))
+ )
+ } else {
+ (extractLabeledPoints(dataset), null)
+ }
+ val numFeatures = trainDataset.first().features.size
val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression)
- val instr = Instrumentation.create(this, oldDataset)
+ val instr = Instrumentation.create(this, dataset)
instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType,
maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode,
seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy)
instr.logNumFeatures(numFeatures)
- val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy,
- $(seed), $(featureSubsetStrategy))
+ val (baseLearners, learnerWeights) = if (withValidation) {
+ GradientBoostedTrees.runWithValidation(trainDataset, validationDataset, boostingStrategy,
+ $(seed), $(featureSubsetStrategy))
+ } else {
+ GradientBoostedTrees.run(trainDataset, boostingStrategy,
+ $(seed), $(featureSubsetStrategy))
+ }
val m = new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures)
instr.logSuccess(m)
m
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index ec8868bb42cbb..00157fe63af41 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -21,6 +21,7 @@ import java.util.Locale
import scala.util.Try
+import org.apache.spark.annotation.Since
import org.apache.spark.ml.PredictorParams
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
@@ -460,18 +461,34 @@ private[ml] trait RandomForestRegressorParams
*
* Note: Marked as private and DeveloperApi since this may be made public in the future.
*/
-private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasStepSize {
-
- /* TODO: Add this doc when we add this param. SPARK-7132
- * Threshold for stopping early when runWithValidation is used.
- * If the error rate on the validation input changes by less than the validationTol,
- * then learning will stop early (before [[numIterations]]).
- * This parameter is ignored when run is used.
- * (default = 1e-5)
+private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasStepSize
+ with HasValidationIndicatorCol {
+
+ /**
+ * Threshold for stopping early when fit with validation is used.
+ * (This parameter is ignored when fit without validation is used.)
+ * Early stopping is decided based on the following logic:
+ * If the current loss on the validation set is greater than 0.01, the diff
+ * of validation error is compared to relative tolerance which is
+ * validationTol * (current loss on the validation set).
+ * If the current loss on the validation set is less than or equal to 0.01,
+ * the diff of validation error is compared to absolute tolerance which is
+ * validationTol * 0.01.
* @group param
+ * @see validationIndicatorCol
*/
- // final val validationTol: DoubleParam = new DoubleParam(this, "validationTol", "")
- // validationTol -> 1e-5
+ @Since("2.4.0")
+ final val validationTol: DoubleParam = new DoubleParam(this, "validationTol",
+ "Threshold for stopping early when fit with validation is used." +
+ "If the error rate on the validation input changes by less than the validationTol," +
+ "then learning will stop early (before `maxIter`)." +
+ "This parameter is ignored when fit without validation is used.",
+ ParamValidators.gtEq(0.0)
+ )
+
+ /** @group getParam */
+ @Since("2.4.0")
+ final def getValidationTol: Double = $(validationTol)
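+
+ // Worked example (illustrative numbers): with validationTol = 0.01 and a current
+ // validation loss of 0.5, boosting stops once the change in validation error between
+ // iterations is smaller than 0.01 * 0.5 = 0.005 (relative tolerance). If the current
+ // validation loss is <= 0.01, the absolute tolerance 0.01 * 0.01 = 1e-4 is used instead.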
/**
* @deprecated This method is deprecated and will be removed in 3.0.0.
@@ -497,7 +514,7 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS
@deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0")
def setStepSize(value: Double): this.type = set(stepSize, value)
- setDefault(maxIter -> 20, stepSize -> 0.1)
+ setDefault(maxIter -> 20, stepSize -> 0.1, validationTol -> 0.01)
setDefault(featureSubsetStrategy -> "all")
@@ -507,7 +524,7 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS
oldAlgo: OldAlgo.Algo): OldBoostingStrategy = {
val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 2, oldAlgo, OldVariance)
// NOTE: The old API does not support "seed" so we ignore it.
- new OldBoostingStrategy(strategy, getOldLossType, getMaxIter, getStepSize)
+ new OldBoostingStrategy(strategy, getOldLossType, getMaxIter, getStepSize, getValidationTol)
}
/** Get old Gradient Boosting Loss type */
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 5e916cc4a9fdd..f327f37bad204 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -144,7 +144,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
val metrics = splits.zipWithIndex.map { case ((training, validation), splitIndex) =>
val trainingDataset = sparkSession.createDataFrame(training, schema).cache()
val validationDataset = sparkSession.createDataFrame(validation, schema).cache()
- logDebug(s"Train split $splitIndex with multiple sets of parameters.")
+ instr.logDebug(s"Train split $splitIndex with multiple sets of parameters.")
// Fit models in a Future for training in parallel
val foldMetricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) =>
@@ -155,7 +155,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
}
// TODO: duplicate evaluator to take extra params from input
val metric = eval.evaluate(model.transform(validationDataset, paramMap))
- logDebug(s"Got metric $metric for model trained with $paramMap.")
+ instr.logDebug(s"Got metric $metric for model trained with $paramMap.")
metric
} (executionContext)
}
@@ -169,12 +169,12 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
foldMetrics
}.transpose.map(_.sum / $(numFolds)) // Calculate average metric over all splits
- logInfo(s"Average cross-validation metrics: ${metrics.toSeq}")
+ instr.logInfo(s"Average cross-validation metrics: ${metrics.toSeq}")
val (bestMetric, bestIndex) =
if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1)
else metrics.zipWithIndex.minBy(_._1)
- logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
- logInfo(s"Best cross-validation metric: $bestMetric.")
+ instr.logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
+ instr.logInfo(s"Best cross-validation metric: $bestMetric.")
val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
instr.logSuccess(bestModel)
copyValues(new CrossValidatorModel(uid, bestModel, metrics)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index 13369c4df7180..14d6a69c36747 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -143,7 +143,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
} else None
// Fit models in a Future for training in parallel
- logDebug(s"Train split with multiple sets of parameters.")
+ instr.logDebug(s"Train split with multiple sets of parameters.")
val metricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) =>
Future[Double] {
val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]]
@@ -153,7 +153,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
}
// TODO: duplicate evaluator to take extra params from input
val metric = eval.evaluate(model.transform(validationDataset, paramMap))
- logDebug(s"Got metric $metric for model trained with $paramMap.")
+ instr.logDebug(s"Got metric $metric for model trained with $paramMap.")
metric
} (executionContext)
}
@@ -165,12 +165,12 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
trainingDataset.unpersist()
validationDataset.unpersist()
- logInfo(s"Train validation split metrics: ${metrics.toSeq}")
+ instr.logInfo(s"Train validation split metrics: ${metrics.toSeq}")
val (bestMetric, bestIndex) =
if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1)
else metrics.zipWithIndex.minBy(_._1)
- logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
- logInfo(s"Best train validation split metric: $bestMetric.")
+ instr.logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
+ instr.logInfo(s"Best train validation split metric: $bestMetric.")
val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
instr.logSuccess(bestModel)
copyValues(new TrainValidationSplitModel(uid, bestModel, metrics)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
index 3247c394dfa64..467130b37c16e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
@@ -58,6 +58,13 @@ private[spark] class Instrumentation[E <: Estimator[_]] private (
s" storageLevel=${dataset.getStorageLevel}")
}
+ /**
+ * Logs a debug message with a prefix that uniquely identifies the training session.
+ */
+ override def logDebug(msg: => String): Unit = {
+ super.logDebug(prefix + msg)
+ }
+
/**
* Logs a warning message with a prefix that uniquely identifies the training session.
*/
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index b8a6e94248421..f915062d77389 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -32,7 +32,7 @@ import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors}
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SparkSession}
-import org.apache.spark.util.BoundedPriorityQueue
+import org.apache.spark.util.{BoundedPriorityQueue, Utils}
/**
* Latent Dirichlet Allocation (LDA) model.
@@ -194,6 +194,8 @@ class LocalLDAModel private[spark] (
override protected[spark] val gammaShape: Double = 100)
extends LDAModel with Serializable {
+ private var seed: Long = Utils.random.nextLong()
+
@Since("1.3.0")
override def k: Int = topics.numCols
@@ -216,6 +218,21 @@ class LocalLDAModel private[spark] (
override protected def formatVersion = "1.0"
+ /**
+ * Random seed for cluster initialization.
+ */
+ @Since("2.4.0")
+ def getSeed: Long = seed
+
+ /**
+ * Set the random seed for cluster initialization.
+ */
+ @Since("2.4.0")
+ def setSeed(seed: Long): this.type = {
+ this.seed = seed
+ this
+ }
+
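+ // Determinism sketch (hypothetical document vector): with a fixed seed, the variational
+ // gamma initialization is reproducible, so repeated calls agree, e.g.
+ //   val doc = Vectors.sparse(model.vocabSize, Array(0, 3), Array(2.0, 1.0))
+ //   model.setSeed(42L)
+ //   model.topicDistribution(doc) == model.topicDistribution(doc)  // same vector both times
+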
@Since("1.5.0")
override def save(sc: SparkContext, path: String): Unit = {
LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix, docConcentration, topicConcentration,
@@ -298,6 +315,7 @@ class LocalLDAModel private[spark] (
// by topic (columns of lambda)
val Elogbeta = LDAUtils.dirichletExpectation(lambda.t).t
val ElogbetaBc = documents.sparkContext.broadcast(Elogbeta)
+ val gammaSeed = this.seed
// Sum bound components for each document:
// component for prob(tokens) + component for prob(document-topic distribution)
@@ -306,7 +324,7 @@ class LocalLDAModel private[spark] (
val localElogbeta = ElogbetaBc.value
var docBound = 0.0D
val (gammad: BDV[Double], _, _) = OnlineLDAOptimizer.variationalTopicInference(
- termCounts, exp(localElogbeta), brzAlpha, gammaShape, k)
+ termCounts, exp(localElogbeta), brzAlpha, gammaShape, k, gammaSeed + id)
val Elogthetad: BDV[Double] = LDAUtils.dirichletExpectation(gammad)
// E[log p(doc | theta, beta)]
@@ -352,6 +370,7 @@ class LocalLDAModel private[spark] (
val docConcentrationBrz = this.docConcentration.asBreeze
val gammaShape = this.gammaShape
val k = this.k
+ val gammaSeed = this.seed
documents.map { case (id: Long, termCounts: Vector) =>
if (termCounts.numNonzeros == 0) {
@@ -362,7 +381,8 @@ class LocalLDAModel private[spark] (
expElogbetaBc.value,
docConcentrationBrz,
gammaShape,
- k)
+ k,
+ gammaSeed + id)
(id, Vectors.dense(normalize(gamma, 1.0).toArray))
}
}
@@ -376,6 +396,7 @@ class LocalLDAModel private[spark] (
val docConcentrationBrz = this.docConcentration.asBreeze
val gammaShape = this.gammaShape
val k = this.k
+ val gammaSeed = this.seed
(termCounts: Vector) =>
if (termCounts.numNonzeros == 0) {
@@ -386,7 +407,8 @@ class LocalLDAModel private[spark] (
expElogbeta,
docConcentrationBrz,
gammaShape,
- k)
+ k,
+ gammaSeed)
Vectors.dense(normalize(gamma, 1.0).toArray)
}
}
@@ -403,6 +425,7 @@ class LocalLDAModel private[spark] (
*/
@Since("2.0.0")
def topicDistribution(document: Vector): Vector = {
+ val gammaSeed = this.seed
val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.asBreeze.toDenseMatrix.t).t)
if (document.numNonzeros == 0) {
Vectors.zeros(this.k)
@@ -412,7 +435,8 @@ class LocalLDAModel private[spark] (
expElogbeta,
this.docConcentration.asBreeze,
gammaShape,
- this.k)
+ this.k,
+ gammaSeed)
Vectors.dense(normalize(gamma, 1.0).toArray)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
index 693a2a31f026b..f8e5f3ed76457 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
@@ -30,6 +30,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector, Vectors}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.Utils
/**
* :: DeveloperApi ::
@@ -464,6 +465,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer with Logging {
val alpha = this.alpha.asBreeze
val gammaShape = this.gammaShape
val optimizeDocConcentration = this.optimizeDocConcentration
+ val seed = randomGenerator.nextLong()
// If and only if optimizeDocConcentration is set true,
// we calculate logphat in the same pass as other statistics.
// No calculation of loghat happens otherwise.
@@ -473,20 +475,21 @@ final class OnlineLDAOptimizer extends LDAOptimizer with Logging {
None
}
- val stats: RDD[(BDM[Double], Option[BDV[Double]], Long)] = batch.mapPartitions { docs =>
- val nonEmptyDocs = docs.filter(_._2.numNonzeros > 0)
-
- val stat = BDM.zeros[Double](k, vocabSize)
- val logphatPartOption = logphatPartOptionBase()
- var nonEmptyDocCount: Long = 0L
- nonEmptyDocs.foreach { case (_, termCounts: Vector) =>
- nonEmptyDocCount += 1
- val (gammad, sstats, ids) = OnlineLDAOptimizer.variationalTopicInference(
- termCounts, expElogbetaBc.value, alpha, gammaShape, k)
- stat(::, ids) := stat(::, ids) + sstats
- logphatPartOption.foreach(_ += LDAUtils.dirichletExpectation(gammad))
- }
- Iterator((stat, logphatPartOption, nonEmptyDocCount))
+ val stats: RDD[(BDM[Double], Option[BDV[Double]], Long)] = batch.mapPartitionsWithIndex {
+ (index, docs) =>
+ val nonEmptyDocs = docs.filter(_._2.numNonzeros > 0)
+
+ val stat = BDM.zeros[Double](k, vocabSize)
+ val logphatPartOption = logphatPartOptionBase()
+ var nonEmptyDocCount: Long = 0L
+ nonEmptyDocs.foreach { case (_, termCounts: Vector) =>
+ nonEmptyDocCount += 1
+ val (gammad, sstats, ids) = OnlineLDAOptimizer.variationalTopicInference(
+ termCounts, expElogbetaBc.value, alpha, gammaShape, k, seed + index)
+ stat(::, ids) := stat(::, ids) + sstats
+ logphatPartOption.foreach(_ += LDAUtils.dirichletExpectation(gammad))
+ }
+ Iterator((stat, logphatPartOption, nonEmptyDocCount))
}
val elementWiseSum = (
@@ -578,7 +581,8 @@ final class OnlineLDAOptimizer extends LDAOptimizer with Logging {
}
override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = {
- new LocalLDAModel(Matrices.fromBreeze(lambda).transpose, alpha, eta, gammaShape)
+ new LocalLDAModel(Matrices.fromBreeze(lambda).transpose, alpha, eta)
+ .setSeed(randomGenerator.nextLong())
}
}
@@ -605,18 +609,20 @@ private[clustering] object OnlineLDAOptimizer {
expElogbeta: BDM[Double],
alpha: breeze.linalg.Vector[Double],
gammaShape: Double,
- k: Int): (BDV[Double], BDM[Double], List[Int]) = {
+ k: Int,
+ seed: Long): (BDV[Double], BDM[Double], List[Int]) = {
val (ids: List[Int], cts: Array[Double]) = termCounts match {
case v: DenseVector => ((0 until v.size).toList, v.values)
case v: SparseVector => (v.indices.toList, v.values)
}
// Initialize the variational distribution q(theta|gamma) for the mini-batch
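+ // Seeding the RandBasis below makes the Gamma draw for gammad reproducible for a given
+ // seed, so repeated inference over the same document yields the same topic mixture.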
+ val randBasis = new RandBasis(new org.apache.commons.math3.random.MersenneTwister(seed))
val gammad: BDV[Double] =
- new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k) // K
+ new Gamma(gammaShape, 1.0 / gammaShape)(randBasis).samplesVector(k) // K
val expElogthetad: BDV[Double] = exp(LDAUtils.dirichletExpectation(gammad)) // K
val expElogbetad = expElogbeta(ids, ::).toDenseMatrix // ids * K
- val phiNorm: BDV[Double] = expElogbetad * expElogthetad +:+ 1e-100 // ids
+ val phiNorm: BDV[Double] = expElogbetad * expElogthetad +:+ 1e-100 // ids
var meanGammaChange = 1D
val ctsVector = new BDV[Double](cts) // ids
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
index 3ca75e8cdb97a..7a5e520d5818e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
@@ -43,7 +43,7 @@ import org.apache.spark.util.random.XORShiftRandom
* $$
* \begin{align}
* c_t+1 &= [(c_t * n_t * a) + (x_t * m_t)] / [n_t + m_t] \\
- * n_t+t &= n_t * a + m_t
+ * n_t+1 &= n_t * a + m_t
* \end{align}
* $$
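+ *
+ * For example (hypothetical numbers), take a one-dimensional center c_t = 1.0 with weight
+ * n_t = 10, decay factor a = 0.5, and a new batch whose mean is x_t = 3.0 over m_t = 10
+ * points. Plugging into the update above: c_t+1 = (1.0*10*0.5 + 3.0*10) / (10 + 10) = 1.75
+ * and n_t+1 = 10*0.5 + 10 = 15.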
*
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index e20de196d65ca..e6d2a8e2b900e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -34,6 +34,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.loss.LogLoss
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.functions.lit
import org.apache.spark.util.Utils
/**
@@ -392,6 +393,51 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
assert(evalArr(2) ~== lossErr3 relTol 1E-3)
}
+ test("runWithValidation stops early and performs better on a validation dataset") {
+ val validationIndicatorCol = "validationIndicator"
+ val trainDF = trainData.toDF().withColumn(validationIndicatorCol, lit(false))
+ val validationDF = validationData.toDF().withColumn(validationIndicatorCol, lit(true))
+
+ val numIter = 20
+ for (lossType <- GBTClassifier.supportedLossTypes) {
+ val gbt = new GBTClassifier()
+ .setSeed(123)
+ .setMaxDepth(2)
+ .setLossType(lossType)
+ .setMaxIter(numIter)
+ val modelWithoutValidation = gbt.fit(trainDF)
+
+ gbt.setValidationIndicatorCol(validationIndicatorCol)
+ val modelWithValidation = gbt.fit(trainDF.union(validationDF))
+
+ assert(modelWithoutValidation.numTrees === numIter)
+ // early stop
+ assert(modelWithValidation.numTrees < numIter)
+
+ val (errorWithoutValidation, errorWithValidation) = {
+ val remappedRdd = validationData.map(x => new LabeledPoint(2 * x.label - 1, x.features))
+ (GradientBoostedTrees.computeError(remappedRdd, modelWithoutValidation.trees,
+ modelWithoutValidation.treeWeights, modelWithoutValidation.getOldLossType),
+ GradientBoostedTrees.computeError(remappedRdd, modelWithValidation.trees,
+ modelWithValidation.treeWeights, modelWithValidation.getOldLossType))
+ }
+ assert(errorWithValidation < errorWithoutValidation)
+
+ val evaluationArray = GradientBoostedTrees
+ .evaluateEachIteration(validationData, modelWithoutValidation.trees,
+ modelWithoutValidation.treeWeights, modelWithoutValidation.getOldLossType,
+ OldAlgo.Classification)
+ assert(evaluationArray.length === numIter)
+ assert(evaluationArray(modelWithValidation.numTrees) >
+ evaluationArray(modelWithValidation.numTrees - 1))
+ var i = 1
+ while (i < modelWithValidation.numTrees) {
+ assert(evaluationArray(i) <= evaluationArray(i - 1))
+ i += 1
+ }
+ }
+ }
+
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
index f3ff2afcad2cd..81842afbddbbb 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
@@ -19,17 +19,18 @@ package org.apache.spark.ml.clustering
import scala.language.existentials
-import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.SparkException
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.clustering.DistanceMeasure
-import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.Dataset
-class BisectingKMeansSuite
- extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+
+class BisectingKMeansSuite extends MLTest with DefaultReadWriteTest {
+
+ import testImplicits._
final val k = 5
@transient var dataset: Dataset[_] = _
@@ -68,10 +69,13 @@ class BisectingKMeansSuite
// Verify fit does not fail on very sparse data
val model = bkm.fit(sparseDataset)
- val result = model.transform(sparseDataset)
- val numClusters = result.select("prediction").distinct().collect().length
- // Verify we hit the edge case
- assert(numClusters < k && numClusters > 1)
+
+ testTransformerByGlobalCheckFunc[Tuple1[Vector]](sparseDataset.toDF(), model, "prediction") {
+ rows =>
+ val numClusters = rows.distinct.length
+ // Verify we hit the edge case
+ assert(numClusters < k && numClusters > 1)
+ }
}
test("setter/getter") {
@@ -104,19 +108,16 @@ class BisectingKMeansSuite
val bkm = new BisectingKMeans().setK(k).setPredictionCol(predictionColName).setSeed(1)
val model = bkm.fit(dataset)
assert(model.clusterCenters.length === k)
-
- val transformed = model.transform(dataset)
- val expectedColumns = Array("features", predictionColName)
- expectedColumns.foreach { column =>
- assert(transformed.columns.contains(column))
- }
- val clusters =
- transformed.select(predictionColName).rdd.map(_.getInt(0)).distinct().collect().toSet
- assert(clusters.size === k)
- assert(clusters === Set(0, 1, 2, 3, 4))
assert(model.computeCost(dataset) < 0.1)
assert(model.hasParent)
+ testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataset.toDF(), model,
+ "features", predictionColName) { rows =>
+ val clusters = rows.map(_.getAs[Int](predictionColName)).toSet
+ assert(clusters.size === k)
+ assert(clusters === Set(0, 1, 2, 3, 4))
+ }
+
// Check validity of model summary
val numRows = dataset.count()
assert(model.hasSummary)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
index d0d461a42711a..0b91f502f615b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
@@ -23,16 +23,15 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.stat.distribution.MultivariateGaussian
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.{Dataset, Row}
-class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
- with DefaultReadWriteTest {
- import testImplicits._
+class GaussianMixtureSuite extends MLTest with DefaultReadWriteTest {
+
import GaussianMixtureSuite._
+ import testImplicits._
final val k = 5
private val seed = 538009335
@@ -119,15 +118,10 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
assert(model.weights.length === k)
assert(model.gaussians.length === k)
- val transformed = model.transform(dataset)
- val expectedColumns = Array("features", predictionColName, probabilityColName)
- expectedColumns.foreach { column =>
- assert(transformed.columns.contains(column))
- }
-
// Check prediction matches the highest probability, and probabilities sum to one.
- transformed.select(predictionColName, probabilityColName).collect().foreach {
- case Row(pred: Int, prob: Vector) =>
+ testTransformer[Tuple1[Vector]](dataset.toDF(), model,
+ "features", predictionColName, probabilityColName) {
+ case Row(_, pred: Int, prob: Vector) =>
val probArray = prob.toArray
val predFromProb = probArray.zipWithIndex.maxBy(_._1)._2
assert(pred === predFromProb)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index 680a7c2034083..2569e7a432ca4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -22,20 +22,21 @@ import scala.util.Random
import org.dmg.pmml.{ClusteringModel, PMML}
-import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.SparkException
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils, PMMLReadWriteTest}
import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel}
+import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans,
+ KMeansModel => MLlibKMeansModel}
import org.apache.spark.mllib.linalg.{Vectors => MLlibVectors}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
private[clustering] case class TestRow(features: Vector)
-class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest
- with PMMLReadWriteTest {
+class KMeansSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTest {
+
+ import testImplicits._
final val k = 5
@transient var dataset: Dataset[_] = _
@@ -109,15 +110,13 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
val model = kmeans.fit(dataset)
assert(model.clusterCenters.length === k)
- val transformed = model.transform(dataset)
- val expectedColumns = Array("features", predictionColName)
- expectedColumns.foreach { column =>
- assert(transformed.columns.contains(column))
+ testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataset.toDF(), model,
+ "features", predictionColName) { rows =>
+ val clusters = rows.map(_.getAs[Int](predictionColName)).toSet
+ assert(clusters.size === k)
+ assert(clusters === Set(0, 1, 2, 3, 4))
}
- val clusters =
- transformed.select(predictionColName).rdd.map(_.getInt(0)).distinct().collect().toSet
- assert(clusters.size === k)
- assert(clusters === Set(0, 1, 2, 3, 4))
+
assert(model.computeCost(dataset) < 0.1)
assert(model.hasParent)
@@ -149,9 +148,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
model.setFeaturesCol(featuresColName).setPredictionCol(predictionColName)
val transformed = model.transform(dataset.withColumnRenamed("features", featuresColName))
- Seq(featuresColName, predictionColName).foreach { column =>
- assert(transformed.columns.contains(column))
- }
+ assert(transformed.schema.fieldNames.toSet === Set(featuresColName, predictionColName))
assert(model.getFeaturesCol == featuresColName)
assert(model.getPredictionCol == predictionColName)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
index 8d728f063dd8c..096b5416899e1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
@@ -21,11 +21,9 @@ import scala.language.existentials
import org.apache.hadoop.fs.Path
-import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{Vector, Vectors}
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql._
object LDASuite {
@@ -61,7 +59,7 @@ object LDASuite {
}
-class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class LDASuite extends MLTest with DefaultReadWriteTest {
import testImplicits._
@@ -186,16 +184,11 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
assert(model.topicsMatrix.numCols === k)
assert(!model.isDistributed)
- // transform()
- val transformed = model.transform(dataset)
- val expectedColumns = Array("features", lda.getTopicDistributionCol)
- expectedColumns.foreach { column =>
- assert(transformed.columns.contains(column))
- }
- transformed.select(lda.getTopicDistributionCol).collect().foreach { r =>
- val topicDistribution = r.getAs[Vector](0)
- assert(topicDistribution.size === k)
- assert(topicDistribution.toArray.forall(w => w >= 0.0 && w <= 1.0))
+ testTransformer[Tuple1[Vector]](dataset.toDF(), model,
+ "features", lda.getTopicDistributionCol) {
+ case Row(_, topicDistribution: Vector) =>
+ assert(topicDistribution.size === k)
+ assert(topicDistribution.toArray.forall(w => w >= 0.0 && w <= 1.0))
}
// logLikelihood, logPerplexity
@@ -253,6 +246,12 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
val lda = new LDA()
testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings,
LDASuite.allParamSettings, checkModelData)
+
+ // Make sure the result is deterministic after saving and loading the model
+ val model = lda.fit(dataset)
+ val model2 = testDefaultReadWrite(model)
+ assert(model.logLikelihood(dataset) ~== model2.logLikelihood(dataset) absTol 1e-6)
+ assert(model.logPerplexity(dataset) ~== model2.logPerplexity(dataset) absTol 1e-6)
}
test("read/write DistributedLDAModel") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index 773f6d2c542fe..b145c7a3dc952 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -28,6 +28,7 @@ import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees =>
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.functions.lit
import org.apache.spark.util.Utils
/**
@@ -231,7 +232,52 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest {
}
}
- /////////////////////////////////////////////////////////////////////////////
+ test("runWithValidation stops early and performs better on a validation dataset") {
+ val validationIndicatorCol = "validationIndicator"
+ val trainDF = trainData.toDF().withColumn(validationIndicatorCol, lit(false))
+ val validationDF = validationData.toDF().withColumn(validationIndicatorCol, lit(true))
+
+ val numIter = 20
+ for (lossType <- GBTRegressor.supportedLossTypes) {
+ val gbt = new GBTRegressor()
+ .setSeed(123)
+ .setMaxDepth(2)
+ .setLossType(lossType)
+ .setMaxIter(numIter)
+ val modelWithoutValidation = gbt.fit(trainDF)
+
+ gbt.setValidationIndicatorCol(validationIndicatorCol)
+ val modelWithValidation = gbt.fit(trainDF.union(validationDF))
+
+ assert(modelWithoutValidation.numTrees === numIter)
+ // early stop
+ assert(modelWithValidation.numTrees < numIter)
+
+ val errorWithoutValidation = GradientBoostedTrees.computeError(validationData,
+ modelWithoutValidation.trees, modelWithoutValidation.treeWeights,
+ modelWithoutValidation.getOldLossType)
+ val errorWithValidation = GradientBoostedTrees.computeError(validationData,
+ modelWithValidation.trees, modelWithValidation.treeWeights,
+ modelWithValidation.getOldLossType)
+
+ assert(errorWithValidation < errorWithoutValidation)
+
+ val evaluationArray = GradientBoostedTrees
+ .evaluateEachIteration(validationData, modelWithoutValidation.trees,
+ modelWithoutValidation.treeWeights, modelWithoutValidation.getOldLossType,
+ OldAlgo.Regression)
+ assert(evaluationArray.length === numIter)
+ assert(evaluationArray(modelWithValidation.numTrees) >
+ evaluationArray(modelWithValidation.numTrees - 1))
+ var i = 1
+ while (i < modelWithValidation.numTrees) {
+ assert(evaluationArray(i) <= evaluationArray(i - 1))
+ i += 1
+ }
+ }
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 7d0e88ee20c3f..4f6d5ff898681 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -36,6 +36,11 @@ object MimaExcludes {
// Exclude rules for 2.4.x
lazy val v24excludes = v23excludes ++ Seq(
+ // [SPARK-20087][CORE] Attach accumulators / metrics to 'TaskKilled' end reason
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.apply"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.copy"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.this"),
+
// [SPARK-22941][core] Do not exit JVM when submit fails with in-process launcher.
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.printWarning"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.parseSparkConfProperty"),
@@ -73,7 +78,18 @@ object MimaExcludes {
ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.InternalNode"),
ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.Node"),
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.this"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.this")
+ ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.this"),
+
+ // [SPARK-7132][ML] Add fit with validation set to spark.ml GBT
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.getValidationIndicatorCol"),
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.org$apache$spark$ml$param$shared$HasValidationIndicatorCol$_setter_$validationIndicatorCol_="),
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.validationIndicatorCol"),
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.getValidationIndicatorCol"),
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.org$apache$spark$ml$param$shared$HasValidationIndicatorCol$_setter_$validationIndicatorCol_="),
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.validationIndicatorCol"),
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.getValidationIndicatorCol"),
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.org$apache$spark$ml$param$shared$HasValidationIndicatorCol$_setter_$validationIndicatorCol_="),
+ ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.validationIndicatorCol")
)
// Exclude rules for 2.3.x
diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py
index ea845b98b3db2..88519d7311fcc 100644
--- a/python/pyspark/cloudpickle.py
+++ b/python/pyspark/cloudpickle.py
@@ -272,7 +272,7 @@ def save_memoryview(self, obj):
if not PY3:
def save_buffer(self, obj):
self.save(str(obj))
- dispatch[buffer] = save_buffer
+ dispatch[buffer] = save_buffer # noqa: F821 'buffer' was removed in Python 3
def save_unsupported(self, obj):
raise pickle.PicklingError("Cannot pickle objects of type %s" % type(obj))
@@ -801,10 +801,10 @@ def save_ellipsis(self, obj):
def save_not_implemented(self, obj):
self.save_reduce(_gen_not_implemented, ())
- if PY3:
- dispatch[io.TextIOWrapper] = save_file
- else:
+ try: # Python 2
dispatch[file] = save_file
+ except NameError: # Python 3
+ dispatch[io.TextIOWrapper] = save_file
dispatch[type(Ellipsis)] = save_ellipsis
dispatch[type(NotImplemented)] = save_not_implemented
@@ -819,6 +819,11 @@ def save_logger(self, obj):
dispatch[logging.Logger] = save_logger
+ def save_root_logger(self, obj):
+ self.save_reduce(logging.getLogger, (), obj=obj)
+
+ dispatch[logging.RootLogger] = save_root_logger
+
"""Special functions for Add-on libraries"""
def inject_addons(self):
"""Plug in system. Register additional pickling functions if modules already loaded"""
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index dbb463f6005a1..ede3b6af0a8cf 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -211,9 +211,21 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
for path in self._conf.get("spark.submit.pyFiles", "").split(","):
if path != "":
(dirname, filename) = os.path.split(path)
- if filename[-4:].lower() in self.PACKAGE_EXTENSIONS:
- self._python_includes.append(filename)
- sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename))
+ try:
+ filepath = os.path.join(SparkFiles.getRootDirectory(), filename)
+ if not os.path.exists(filepath):
+ # In case of YARN with shell mode, 'spark.submit.pyFiles' files are
+ # not added via SparkContext.addFile. Here we check if the file exists,
+ # try to copy and then add it to the path. See SPARK-21945.
+ shutil.copyfile(path, filepath)
+ if filename[-4:].lower() in self.PACKAGE_EXTENSIONS:
+ self._python_includes.append(filename)
+ sys.path.insert(1, filepath)
+ except Exception:
+ warnings.warn(
+ "Failed to add file [%s] speficied in 'spark.submit.pyFiles' to "
+ "Python path:\n %s" % (path, "\n ".join(sys.path)),
+ RuntimeWarning)
# Create a temporary directory inside spark.local.dir:
local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir(self._jsc.sc().conf())
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index ec17653a1adf9..424ecfd89b060 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -1222,6 +1222,10 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
True
>>> model.trees
[DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...]
+ >>> validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0),)],
+ ... ["indexed", "features"])
+ >>> model.evaluateEachIteration(validation)
+ [0.25..., 0.23..., 0.21..., 0.19..., 0.18...]
.. versionadded:: 1.4.0
"""
@@ -1319,6 +1323,17 @@ def trees(self):
"""Trees in this ensemble. Warning: These have null parent Estimators."""
return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]
+ @since("2.4.0")
+ def evaluateEachIteration(self, dataset):
+ """
+ Method to compute error or loss for every iteration of gradient boosting.
+
+ :param dataset:
+ Test dataset to evaluate model on, where dataset is an
+ instance of :py:class:`pyspark.sql.DataFrame`
+ """
+ return self._call_java("evaluateEachIteration", dataset)
+
@inherit_doc
class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol,
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 9a66d87d7f211..dd0b62f184d26 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -1056,6 +1056,10 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
True
>>> model.trees
[DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...]
+ >>> validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0))],
+ ... ["label", "features"])
+ >>> model.evaluateEachIteration(validation, "squared")
+ [0.0, 0.0, 0.0, 0.0, 0.0]
.. versionadded:: 1.4.0
"""
@@ -1156,6 +1160,20 @@ def trees(self):
"""Trees in this ensemble. Warning: These have null parent Estimators."""
return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))]
+ @since("2.4.0")
+ def evaluateEachIteration(self, dataset, loss):
+ """
+ Method to compute error or loss for every iteration of gradient boosting.
+
+ :param dataset:
+ Test dataset to evaluate model on, where dataset is an
+ instance of :py:class:`pyspark.sql.DataFrame`
+ :param loss:
+ The loss function used to compute error.
+ Supported options: squared, absolute
+ """
+ return self._call_java("evaluateEachIteration", dataset, loss)
+
@inherit_doc
class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 093593132e56d..0dde0db9e3339 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -1595,6 +1595,44 @@ def test_default_read_write(self):
self.assertEqual(lr.uid, lr3.uid)
self.assertEqual(lr.extractParamMap(), lr3.extractParamMap())
+ def test_default_read_write_default_params(self):
+ lr = LogisticRegression()
+ self.assertFalse(lr.isSet(lr.getParam("threshold")))
+
+ lr.setMaxIter(50)
+ lr.setThreshold(.75)
+
+ # `threshold` is set by user, default param `predictionCol` is not set by user.
+ self.assertTrue(lr.isSet(lr.getParam("threshold")))
+ self.assertFalse(lr.isSet(lr.getParam("predictionCol")))
+ self.assertTrue(lr.hasDefault(lr.getParam("predictionCol")))
+
+ writer = DefaultParamsWriter(lr)
+ metadata = json.loads(writer._get_metadata_to_save(lr, self.sc))
+ self.assertTrue("defaultParamMap" in metadata)
+
+ reader = DefaultParamsReadable.read()
+ metadataStr = json.dumps(metadata, separators=[',', ':'])
+ loadedMetadata = reader._parseMetaData(metadataStr)
+ reader.getAndSetParams(lr, loadedMetadata)
+
+ self.assertTrue(lr.isSet(lr.getParam("threshold")))
+ self.assertFalse(lr.isSet(lr.getParam("predictionCol")))
+ self.assertTrue(lr.hasDefault(lr.getParam("predictionCol")))
+
+ # manually create metadata without `defaultParamMap` section.
+ del metadata['defaultParamMap']
+ metadataStr = json.dumps(metadata, separators=[',', ':'])
+ loadedMetadata = reader._parseMetaData(metadataStr)
+ with self.assertRaisesRegexp(AssertionError, "`defaultParamMap` section not found"):
+ reader.getAndSetParams(lr, loadedMetadata)
+
+ # Prior to 2.4.0, metadata doesn't have `defaultParamMap`.
+ metadata['sparkVersion'] = '2.3.0'
+ metadataStr = json.dumps(metadata, separators=[',', ':'])
+ loadedMetadata = reader._parseMetaData(metadataStr)
+ reader.getAndSetParams(lr, loadedMetadata)
+
class LDATest(SparkSessionTestCase):
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index a486c6a3fdeb5..9fa85664939b8 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -30,6 +30,7 @@
from pyspark import SparkContext, since
from pyspark.ml.common import inherit_doc
from pyspark.sql import SparkSession
+from pyspark.util import VersionUtils
def _jvm():
@@ -396,6 +397,7 @@ def saveMetadata(instance, path, sc, extraMetadata=None, paramMap=None):
- sparkVersion
- uid
- paramMap
+ - defaultParamMap (since 2.4.0)
- (optionally, extra metadata)
:param extraMetadata: Extra metadata to be saved at same level as uid, paramMap, etc.
:param paramMap: If given, this is saved in the "paramMap" field.
@@ -417,15 +419,24 @@ def _get_metadata_to_save(instance, sc, extraMetadata=None, paramMap=None):
"""
uid = instance.uid
cls = instance.__module__ + '.' + instance.__class__.__name__
- params = instance.extractParamMap()
+
+ # User-supplied param values
+ params = instance._paramMap
jsonParams = {}
if paramMap is not None:
jsonParams = paramMap
else:
for p in params:
jsonParams[p.name] = params[p]
+
+ # Default param values
+ jsonDefaultParams = {}
+ for p in instance._defaultParamMap:
+ jsonDefaultParams[p.name] = instance._defaultParamMap[p]
+
basicMetadata = {"class": cls, "timestamp": long(round(time.time() * 1000)),
- "sparkVersion": sc.version, "uid": uid, "paramMap": jsonParams}
+ "sparkVersion": sc.version, "uid": uid, "paramMap": jsonParams,
+ "defaultParamMap": jsonDefaultParams}
if extraMetadata is not None:
basicMetadata.update(extraMetadata)
return json.dumps(basicMetadata, separators=[',', ':'])
@@ -523,11 +534,26 @@ def getAndSetParams(instance, metadata):
"""
Extract Params from metadata, and set them in the instance.
"""
+ # Set user-supplied param values
for paramName in metadata['paramMap']:
param = instance.getParam(paramName)
paramValue = metadata['paramMap'][paramName]
instance.set(param, paramValue)
+ # Set default param values
+ majorAndMinorVersions = VersionUtils.majorMinorVersion(metadata['sparkVersion'])
+ major = majorAndMinorVersions[0]
+ minor = majorAndMinorVersions[1]
+
+ # For metadata files written prior to Spark 2.4, there is no default section.
+ if major > 2 or (major == 2 and minor >= 4):
+ assert 'defaultParamMap' in metadata, "Error loading metadata: Expected " + \
+ "`defaultParamMap` section not found"
+
+ for paramName in metadata['defaultParamMap']:
+ paramValue = metadata['defaultParamMap'][paramName]
+ instance._setDefault(**{paramName: paramValue})
+
@staticmethod
def loadParamsInstance(path, sc):
"""
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index f5a584152b4f6..fbc8a2d038f8f 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1108,8 +1108,11 @@ def add_months(start, months):
@since(1.5)
def months_between(date1, date2, roundOff=True):
"""
- Returns the number of months between date1 and date2.
- Unless `roundOff` is set to `False`, the result is rounded off to 8 digits.
+ Returns the number of months between dates `date1` and `date2`.
+ If `date1` is later than `date2`, the result is positive.
+ If `date1` and `date2` are on the same day of the month, or both are the last day of the month,
+ the result is an integer (the time of day is ignored).
+ The result is rounded off to 8 digits unless `roundOff` is set to `False`.
>>> df = spark.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['date1', 'date2'])
>>> df.select(months_between(df.date1, df.date2).alias('months')).collect()
@@ -1852,6 +1855,21 @@ def array_contains(col, value):
return Column(sc._jvm.functions.array_contains(_to_java_column(col), value))
+@since(2.4)
+def arrays_overlap(a1, a2):
+ """
+ Collection function: returns true if the arrays contain any common non-null element; if not,
+ returns null if both arrays are non-empty and either of them contains a null element;
+ otherwise returns false.
+
+ >>> df = spark.createDataFrame([(["a", "b"], ["b", "c"]), (["a"], ["b", "c"])], ['x', 'y'])
+ >>> df.select(arrays_overlap(df.x, df.y).alias("overlap")).collect()
+ [Row(overlap=True), Row(overlap=False)]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.arrays_overlap(_to_java_column(a1), _to_java_column(a2)))
+
+
@since(2.4)
def slice(x, start, length):
"""
@@ -2092,12 +2110,13 @@ def json_tuple(col, *fields):
return Column(jc)
+@ignore_unicode_prefix
@since(2.1)
def from_json(col, schema, options={}):
"""
- Parses a column containing a JSON string into a :class:`StructType` or :class:`ArrayType`
- of :class:`StructType`\\s with the specified schema. Returns `null`, in the case of an
- unparseable string.
+ Parses a column containing a JSON string into a :class:`MapType` with :class:`StringType`
+ as the key type, a :class:`StructType`, or an :class:`ArrayType` of :class:`StructType`\\s
+ with the specified schema. Returns `null` in the case of an unparseable string.
:param col: string column in json format
:param schema: a StructType or ArrayType of StructType to use when parsing the json column.
@@ -2114,6 +2133,9 @@ def from_json(col, schema, options={}):
[Row(json=Row(a=1))]
>>> df.select(from_json(df.value, "a INT").alias("json")).collect()
[Row(json=Row(a=1))]
+ >>> schema = MapType(StringType(), IntegerType())
+ >>> df.select(from_json(df.value, schema).alias("json")).collect()
+ [Row(json={u'a': 1})]
>>> data = [(1, '''[{"a": 1}]''')]
>>> schema = ArrayType(StructType([StructField("a", IntegerType())]))
>>> df = spark.createDataFrame(data, ("key", "value"))
@@ -2322,6 +2344,40 @@ def map_values(col):
return Column(sc._jvm.functions.map_values(_to_java_column(col)))
+@since(2.4)
+def map_entries(col):
+ """
+ Collection function: Returns an unordered array of all entries in the given map.
+
+ :param col: name of column or expression
+
+ >>> from pyspark.sql.functions import map_entries
+ >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as data")
+ >>> df.select(map_entries("data").alias("entries")).show()
+ +----------------+
+ | entries|
+ +----------------+
+ |[[1, a], [2, b]]|
+ +----------------+
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.map_entries(_to_java_column(col)))
+
+
+@ignore_unicode_prefix
+@since(2.4)
+def array_repeat(col, count):
+ """
+ Collection function: creates an array containing a column repeated count times.
+
+ >>> df = spark.createDataFrame([('ab',)], ['data'])
+ >>> df.select(array_repeat(df.data, 3).alias('r')).collect()
+ [Row(r=[u'ab', u'ab', u'ab'])]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.array_repeat(_to_java_column(col), count))
+
+
# ---------------------------- User Defined Function ----------------------------------
class PandasUDFType(object):
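A short illustration of the null semantics described in the `arrays_overlap` docstring, alongside the other new collection functions (assumes an active `SparkSession` named `spark`; not part of the patch):

```python
from pyspark.sql.functions import arrays_overlap, array_repeat

df = spark.createDataFrame([(["a", None], ["b"])], ["x", "y"])

# No common non-null element, both arrays non-empty, one contains a null -> null.
df.select(arrays_overlap(df.x, df.y).alias("overlap")).show()

# Repeating the whole array column yields an array of arrays.
df.select(array_repeat(df.y, 2).alias("r")).show()

# map_entries is also exposed through SQL via the FunctionRegistry entry further below.
spark.sql("SELECT map_entries(map(1, 'a', 2, 'b')) AS entries").show()
```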
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 16aa9378ad8ee..c7bd8f01b907f 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -4680,6 +4680,26 @@ def test_supported_types(self):
self.assertPandasEqual(expected2, result2)
self.assertPandasEqual(expected3, result3)
+ def test_array_type_correct(self):
+ from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col
+
+ df = self.data.withColumn("arr", array(col("id"))).repartition(1, "id")
+
+ output_schema = StructType(
+ [StructField('id', LongType()),
+ StructField('v', IntegerType()),
+ StructField('arr', ArrayType(LongType()))])
+
+ udf = pandas_udf(
+ lambda pdf: pdf,
+ output_schema,
+ PandasUDFType.GROUPED_MAP
+ )
+
+ result = df.groupby('id').apply(udf).sort('id').toPandas()
+ expected = df.toPandas().groupby('id').apply(udf.func).reset_index(drop=True)
+ self.assertPandasEqual(expected, result)
+
def test_register_grouped_map_udf(self):
from pyspark.sql.functions import pandas_udf, PandasUDFType
@@ -5219,8 +5239,8 @@ def test_complex_groupby(self):
expected2 = df.groupby().agg(sum(df.v))
# groupby one column and one sql expression
- result3 = df.groupby(df.id, df.v % 2).agg(sum_udf(df.v))
- expected3 = df.groupby(df.id, df.v % 2).agg(sum(df.v))
+ result3 = df.groupby(df.id, df.v % 2).agg(sum_udf(df.v)).orderBy(df.id, df.v % 2)
+ expected3 = df.groupby(df.id, df.v % 2).agg(sum(df.v)).orderBy(df.id, df.v % 2)
# groupby one python UDF
result4 = df.groupby(plus_one(df.id)).agg(sum_udf(df.v))
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 8bb63fcc7ff9c..5d2e58bef6466 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -82,7 +82,7 @@ def wrap_scalar_pandas_udf(f, return_type):
def verify_result_length(*a):
result = f(*a)
if not hasattr(result, "__len__"):
- raise TypeError("Return type of the user-defined functon should be "
+ raise TypeError("Return type of the user-defined function should be "
"Pandas.Series, but is {}".format(type(result)))
if len(result) != len(a[0]):
raise RuntimeError("Result vector from pandas_udf was not the required length: "
diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala
index 022191d0070fd..91f64141e5318 100644
--- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala
+++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala
@@ -39,7 +39,7 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver")
Cannot find driver {driverId}
- return UIUtils.basicSparkPage(content, s"Details for Job $driverId")
+ return UIUtils.basicSparkPage(request, content, s"Details for Job $driverId")
}
val driverState = state.get
val driverHeaders = Seq("Driver property", "Value")
@@ -68,7 +68,7 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver")
retryHeaders, retryRow, Iterable.apply(driverState.description.retryState))
val content =
Driver state information for driver id {driverId}
- Back to Drivers
+ Back to Drivers
Driver state: {driverState.state}
@@ -87,7 +87,7 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver")
;
- UIUtils.basicSparkPage(content, s"Details for Job $driverId")
+ UIUtils.basicSparkPage(request, content, s"Details for Job $driverId")
}
private def launchedRow(submissionState: Option[MesosClusterSubmissionState]): Seq[Node] = {
diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala
index 88a6614d51384..c53285331ea68 100644
--- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala
+++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala
@@ -62,7 +62,7 @@ private[mesos] class MesosClusterPage(parent: MesosClusterUI) extends WebUIPage(
{retryTable}
;
- UIUtils.basicSparkPage(content, "Spark Drivers for Mesos cluster")
+ UIUtils.basicSparkPage(request, content, "Spark Drivers for Mesos cluster")
}
private def queuedRow(submission: MesosDriverDescription): Seq[Node] = {
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
index 8eda6cb1277c5..7250e58b6c49a 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
@@ -200,7 +200,29 @@ object YarnSparkHadoopUtil {
.map(new Path(_).getFileSystem(hadoopConf))
.getOrElse(FileSystem.get(hadoopConf))
- filesystemsToAccess + stagingFS
+ // Add the list of available namenodes for all namespaces in HDFS federation.
+ // If ViewFS is enabled, this is skipped as ViewFS already handles delegation tokens for its
+ // namespaces.
+ val hadoopFilesystems = if (stagingFS.getScheme == "viewfs") {
+ Set.empty
+ } else {
+ val nameservices = hadoopConf.getTrimmedStrings("dfs.nameservices")
+ // Retrieving the filesystem for the nameservices where HA is not enabled
+ val filesystemsWithoutHA = nameservices.flatMap { ns =>
+ Option(hadoopConf.get(s"dfs.namenode.rpc-address.$ns")).map { nameNode =>
+ new Path(s"hdfs://$nameNode").getFileSystem(hadoopConf)
+ }
+ }
+ // Retrieving the filesystem for the nameservices where HA is enabled
+ val filesystemsWithHA = nameservices.flatMap { ns =>
+ Option(hadoopConf.get(s"dfs.ha.namenodes.$ns")).map { _ =>
+ new Path(s"hdfs://$ns").getFileSystem(hadoopConf)
+ }
+ }
+ (filesystemsWithoutHA ++ filesystemsWithHA).toSet
+ }
+
+ filesystemsToAccess ++ hadoopFilesystems + stagingFS
}
}
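To make the nameservice handling above easier to follow, here is an illustrative-only Python sketch over a plain dict of Hadoop property names (the real code works on a Hadoop `Configuration`; the helper is hypothetical):

```python
def namenode_uris(conf):
    """Hypothetical mirror of the selection logic: plain dict of HDFS properties in, URIs out."""
    uris = []
    for ns in filter(None, conf.get("dfs.nameservices", "").split(",")):
        rpc = conf.get("dfs.namenode.rpc-address." + ns)
        if rpc is not None:                        # nameservice without HA
            uris.append("hdfs://" + rpc)
        if "dfs.ha.namenodes." + ns in conf:       # nameservice with HA enabled
            uris.append("hdfs://" + ns)
    return uris

conf = {
    "dfs.nameservices": "ns1,nsHA",
    "dfs.namenode.rpc-address.ns1": "nn1.example.com:8020",
    "dfs.ha.namenodes.nsHA": "nn1,nn2",
}
print(namenode_uris(conf))   # ['hdfs://nn1.example.com:8020', 'hdfs://nsHA']
```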
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala
index f21353aa007c8..61c0c43f7c04f 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala
@@ -21,7 +21,8 @@ import java.io.{File, IOException}
import java.nio.charset.StandardCharsets
import com.google.common.io.{ByteStreams, Files}
-import org.apache.hadoop.io.Text
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
import org.apache.hadoop.yarn.api.records.ApplicationAccessType
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.scalatest.Matchers
@@ -141,4 +142,66 @@ class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging
}
+ test("SPARK-24149: retrieve all namenodes from HDFS") {
+ val sparkConf = new SparkConf()
+ val basicFederationConf = new Configuration()
+ basicFederationConf.set("fs.defaultFS", "hdfs://localhost:8020")
+ basicFederationConf.set("dfs.nameservices", "ns1,ns2")
+ basicFederationConf.set("dfs.namenode.rpc-address.ns1", "localhost:8020")
+ basicFederationConf.set("dfs.namenode.rpc-address.ns2", "localhost:8021")
+ val basicFederationExpected = Set(
+ new Path("hdfs://localhost:8020").getFileSystem(basicFederationConf),
+ new Path("hdfs://localhost:8021").getFileSystem(basicFederationConf))
+ val basicFederationResult = YarnSparkHadoopUtil.hadoopFSsToAccess(
+ sparkConf, basicFederationConf)
+ basicFederationResult should be (basicFederationExpected)
+
+ // When ViewFS is enabled, namespaces are handled by it, so we don't need to take care of them
+ val viewFsConf = new Configuration()
+ viewFsConf.addResource(basicFederationConf)
+ viewFsConf.set("fs.defaultFS", "viewfs://clusterX/")
+ viewFsConf.set("fs.viewfs.mounttable.clusterX.link./home", "hdfs://localhost:8020/")
+ val viewFsExpected = Set(new Path("viewfs://clusterX/").getFileSystem(viewFsConf))
+ YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, viewFsConf) should be (viewFsExpected)
+
+ // invalid config should not throw NullPointerException
+ val invalidFederationConf = new Configuration()
+ invalidFederationConf.addResource(basicFederationConf)
+ invalidFederationConf.unset("dfs.namenode.rpc-address.ns2")
+ val invalidFederationExpected = Set(
+ new Path("hdfs://localhost:8020").getFileSystem(invalidFederationConf))
+ val invalidFederationResult = YarnSparkHadoopUtil.hadoopFSsToAccess(
+ sparkConf, invalidFederationConf)
+ invalidFederationResult should be (invalidFederationExpected)
+
+ // no namespaces defined, i.e. the old (non-federated) case
+ val noFederationConf = new Configuration()
+ noFederationConf.set("fs.defaultFS", "hdfs://localhost:8020")
+ val noFederationExpected = Set(
+ new Path("hdfs://localhost:8020").getFileSystem(noFederationConf))
+ val noFederationResult = YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, noFederationConf)
+ noFederationResult should be (noFederationExpected)
+
+ // federation and HA enabled
+ val federationAndHAConf = new Configuration()
+ federationAndHAConf.set("fs.defaultFS", "hdfs://clusterXHA")
+ federationAndHAConf.set("dfs.nameservices", "clusterXHA,clusterYHA")
+ federationAndHAConf.set("dfs.ha.namenodes.clusterXHA", "x-nn1,x-nn2")
+ federationAndHAConf.set("dfs.ha.namenodes.clusterYHA", "y-nn1,y-nn2")
+ federationAndHAConf.set("dfs.namenode.rpc-address.clusterXHA.x-nn1", "localhost:8020")
+ federationAndHAConf.set("dfs.namenode.rpc-address.clusterXHA.x-nn2", "localhost:8021")
+ federationAndHAConf.set("dfs.namenode.rpc-address.clusterYHA.y-nn1", "localhost:8022")
+ federationAndHAConf.set("dfs.namenode.rpc-address.clusterYHA.y-nn2", "localhost:8023")
+ federationAndHAConf.set("dfs.client.failover.proxy.provider.clusterXHA",
+ "org.apache.hadoop.hdfs.server.namenode.ha.ConfiguredFailoverProxyProvider")
+ federationAndHAConf.set("dfs.client.failover.proxy.provider.clusterYHA",
+ "org.apache.hadoop.hdfs.server.namenode.ha.ConfiguredFailoverProxyProvider")
+
+ val federationAndHAExpected = Set(
+ new Path("hdfs://clusterXHA").getFileSystem(federationAndHAConf),
+ new Path("hdfs://clusterYHA").getFileSystem(federationAndHAConf))
+ val federationAndHAResult = YarnSparkHadoopUtil.hadoopFSsToAccess(
+ sparkConf, federationAndHAConf)
+ federationAndHAResult should be (federationAndHAExpected)
+ }
}
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
index f7f921ec22c35..7c54851097af3 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
@@ -398,7 +398,7 @@ hintStatement
;
fromClause
- : FROM relation (',' relation)* (pivotClause | lateralView*)?
+ : FROM relation (',' relation)* lateralView* pivotClause?
;
aggregation
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 29a1411241cf6..469b0e60cc9a2 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -62,6 +62,8 @@
*/
public final class UnsafeRow extends InternalRow implements Externalizable, KryoSerializable {
+ public static final int WORD_SIZE = 8;
+
//////////////////////////////////////////////////////////////////////////////
// Static methods
//////////////////////////////////////////////////////////////////////////////
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
index ccdb6bc5d4b7c..7b02317b8538f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
@@ -68,10 +68,10 @@ import org.apache.spark.sql.types._
*/
@Experimental
@InterfaceStability.Evolving
-@implicitNotFound("Unable to find encoder for type stored in a Dataset. Primitive types " +
- "(Int, String, etc) and Product types (case classes) are supported by importing " +
- "spark.implicits._ Support for serializing other types will be added in future " +
- "releases.")
+@implicitNotFound("Unable to find encoder for type ${T}. An implicit Encoder[${T}] is needed to " +
+ "store ${T} instances in a Dataset. Primitive types (Int, String, etc) and Product types (case " +
+ "classes) are supported by importing spark.implicits._ Support for serializing other types " +
+ "will be added in future releases.")
trait Encoder[T] extends Serializable {
/** Returns the schema of encoding this type of object as a Row. */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index dfdcdbc1eb2c7..3eaa9ecf5d075 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -676,13 +676,13 @@ class Analyzer(
try {
catalog.lookupRelation(tableIdentWithDb)
} catch {
- case _: NoSuchTableException =>
- u.failAnalysis(s"Table or view not found: ${tableIdentWithDb.unquotedString}")
+ case e: NoSuchTableException =>
+ u.failAnalysis(s"Table or view not found: ${tableIdentWithDb.unquotedString}", e)
// If the database is defined and that database is not found, throw an AnalysisException.
// Note that if the database is not defined, it is possible we are looking up a temp view.
case e: NoSuchDatabaseException =>
u.failAnalysis(s"Table or view not found: ${tableIdentWithDb.unquotedString}, the " +
- s"database ${e.db} doesn't exist.")
+ s"database ${e.db} doesn't exist.", e)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 94b0561529e71..90bda2a72ad82 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
/**
@@ -261,9 +260,7 @@ trait CheckAnalysis extends PredicateHelper {
// Check if the data types match.
dataTypes(child).zip(ref).zipWithIndex.foreach { case ((dt1, dt2), ci) =>
// SPARK-18058: we shall not care about the nullability of columns
- val widerType = TypeCoercion.findWiderTypeForTwo(
- dt1.asNullable, dt2.asNullable, SQLConf.get.caseSensitiveAnalysis)
- if (widerType.isEmpty) {
+ if (TypeCoercion.findWiderTypeForTwo(dt1.asNullable, dt2.asNullable).isEmpty) {
failAnalysis(
s"""
|${operator.nodeName} can only be performed on tables with the compatible
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index f9fde53be22ae..23a4a440fac23 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -410,6 +410,7 @@ object FunctionRegistry {
// collection functions
expression[CreateArray]("array"),
expression[ArrayContains]("array_contains"),
+ expression[ArraysOverlap]("arrays_overlap"),
expression[ArrayJoin]("array_join"),
expression[ArrayPosition]("array_position"),
expression[ArraySort]("array_sort"),
@@ -418,6 +419,7 @@ object FunctionRegistry {
expression[ElementAt]("element_at"),
expression[MapKeys]("map_keys"),
expression[MapValues]("map_values"),
+ expression[MapEntries]("map_entries"),
expression[Size]("size"),
expression[Slice]("slice"),
expression[Size]("cardinality"),
@@ -427,6 +429,7 @@ object FunctionRegistry {
expression[Reverse]("reverse"),
expression[Concat]("concat"),
expression[Flatten]("flatten"),
+ expression[ArrayRepeat]("array_repeat"),
CreateStruct.registryEntry,
// mask functions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala
index 4eb6e642b1c37..71ed75454cd4d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala
@@ -83,9 +83,7 @@ case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] with Cas
// For each column, traverse all the values and find a common data type and nullability.
val fields = table.rows.transpose.zip(table.names).map { case (column, name) =>
val inputTypes = column.map(_.dataType)
- val wideType = TypeCoercion.findWiderTypeWithoutStringPromotion(
- inputTypes, conf.caseSensitiveAnalysis)
- val tpe = wideType.getOrElse {
+ val tpe = TypeCoercion.findWiderTypeWithoutStringPromotion(inputTypes).getOrElse {
table.failAnalysis(s"incompatible types found in column $name for inline table")
}
StructField(name, tpe, nullable = column.exists(_.nullable))
@@ -105,7 +103,7 @@ case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] with Cas
castedExpr.eval()
} catch {
case NonFatal(ex) =>
- table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}")
+ table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}", ex)
}
})
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index a7ba201509b78..b2817b0538a7f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -48,18 +48,18 @@ object TypeCoercion {
def typeCoercionRules(conf: SQLConf): List[Rule[LogicalPlan]] =
InConversion(conf) ::
- WidenSetOperationTypes(conf) ::
+ WidenSetOperationTypes ::
PromoteStrings(conf) ::
DecimalPrecision ::
BooleanEquality ::
- FunctionArgumentConversion(conf) ::
+ FunctionArgumentConversion ::
ConcatCoercion(conf) ::
EltCoercion(conf) ::
- CaseWhenCoercion(conf) ::
- IfCoercion(conf) ::
+ CaseWhenCoercion ::
+ IfCoercion ::
StackCoercion ::
Division ::
- ImplicitTypeCasts(conf) ::
+ new ImplicitTypeCasts(conf) ::
DateTimeOperations ::
WindowFrameCoercion ::
Nil
@@ -83,10 +83,7 @@ object TypeCoercion {
* with primitive types, because in that case the precision and scale of the result depends on
* the operation. Those rules are implemented in [[DecimalPrecision]].
*/
- def findTightestCommonType(
- left: DataType,
- right: DataType,
- caseSensitive: Boolean): Option[DataType] = (left, right) match {
+ val findTightestCommonType: (DataType, DataType) => Option[DataType] = {
case (t1, t2) if t1 == t2 => Some(t1)
case (NullType, t1) => Some(t1)
case (t1, NullType) => Some(t1)
@@ -105,32 +102,22 @@ object TypeCoercion {
case (_: TimestampType, _: DateType) | (_: DateType, _: TimestampType) =>
Some(TimestampType)
- case (t1 @ StructType(fields1), t2 @ StructType(fields2)) =>
- val isSameType = if (caseSensitive) {
- DataType.equalsIgnoreNullability(t1, t2)
- } else {
- DataType.equalsIgnoreCaseAndNullability(t1, t2)
- }
-
- if (isSameType) {
- Some(StructType(fields1.zip(fields2).map { case (f1, f2) =>
- // Since t1 is same type of t2, two StructTypes have the same DataType
- // except `name` (in case of `spark.sql.caseSensitive=false`) and `nullable`.
- // - Different names: use f1.name
- // - Different nullabilities: `nullable` is true iff one of them is nullable.
- val dataType = findTightestCommonType(f1.dataType, f2.dataType, caseSensitive).get
- StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable)
- }))
- } else {
- None
- }
+ case (t1 @ StructType(fields1), t2 @ StructType(fields2)) if t1.sameType(t2) =>
+ Some(StructType(fields1.zip(fields2).map { case (f1, f2) =>
+ // Since `t1.sameType(t2)` is true, two StructTypes have the same DataType
+ // except `name` (in case of `spark.sql.caseSensitive=false`) and `nullable`.
+ // - Different names: use f1.name
+ // - Different nullabilities: `nullable` is true iff one of them is nullable.
+ val dataType = findTightestCommonType(f1.dataType, f2.dataType).get
+ StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable)
+ }))
case (a1 @ ArrayType(et1, hasNull1), a2 @ ArrayType(et2, hasNull2)) if a1.sameType(a2) =>
- findTightestCommonType(et1, et2, caseSensitive).map(ArrayType(_, hasNull1 || hasNull2))
+ findTightestCommonType(et1, et2).map(ArrayType(_, hasNull1 || hasNull2))
case (m1 @ MapType(kt1, vt1, hasNull1), m2 @ MapType(kt2, vt2, hasNull2)) if m1.sameType(m2) =>
- val keyType = findTightestCommonType(kt1, kt2, caseSensitive)
- val valueType = findTightestCommonType(vt1, vt2, caseSensitive)
+ val keyType = findTightestCommonType(kt1, kt2)
+ val valueType = findTightestCommonType(vt1, vt2)
Some(MapType(keyType.get, valueType.get, hasNull1 || hasNull2))
case _ => None
@@ -185,14 +172,13 @@ object TypeCoercion {
* i.e. the main difference with [[findTightestCommonType]] is that here we allow some
* loss of precision when widening decimal and double, and promotion to string.
*/
- def findWiderTypeForTwo(t1: DataType, t2: DataType, caseSensitive: Boolean): Option[DataType] = {
- findTightestCommonType(t1, t2, caseSensitive)
+ def findWiderTypeForTwo(t1: DataType, t2: DataType): Option[DataType] = {
+ findTightestCommonType(t1, t2)
.orElse(findWiderTypeForDecimal(t1, t2))
.orElse(stringPromotion(t1, t2))
.orElse((t1, t2) match {
case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) =>
- findWiderTypeForTwo(et1, et2, caseSensitive)
- .map(ArrayType(_, containsNull1 || containsNull2))
+ findWiderTypeForTwo(et1, et2).map(ArrayType(_, containsNull1 || containsNull2))
case _ => None
})
}
@@ -207,8 +193,7 @@ object TypeCoercion {
case _ => false
}
- private def findWiderCommonType(
- types: Seq[DataType], caseSensitive: Boolean): Option[DataType] = {
+ private def findWiderCommonType(types: Seq[DataType]): Option[DataType] = {
// findWiderTypeForTwo doesn't satisfy the associative law, i.e. (a op b) op c may not equal
// to a op (b op c). This is only a problem for StringType or nested StringType in ArrayType.
// Excluding these types, findWiderTypeForTwo satisfies the associative law. For instance,
@@ -216,7 +201,7 @@ object TypeCoercion {
val (stringTypes, nonStringTypes) = types.partition(hasStringType(_))
(stringTypes.distinct ++ nonStringTypes).foldLeft[Option[DataType]](Some(NullType))((r, c) =>
r match {
- case Some(d) => findWiderTypeForTwo(d, c, caseSensitive)
+ case Some(d) => findWiderTypeForTwo(d, c)
case _ => None
})
}
@@ -228,22 +213,20 @@ object TypeCoercion {
*/
private[analysis] def findWiderTypeWithoutStringPromotionForTwo(
t1: DataType,
- t2: DataType,
- caseSensitive: Boolean): Option[DataType] = {
- findTightestCommonType(t1, t2, caseSensitive)
+ t2: DataType): Option[DataType] = {
+ findTightestCommonType(t1, t2)
.orElse(findWiderTypeForDecimal(t1, t2))
.orElse((t1, t2) match {
case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) =>
- findWiderTypeWithoutStringPromotionForTwo(et1, et2, caseSensitive)
+ findWiderTypeWithoutStringPromotionForTwo(et1, et2)
.map(ArrayType(_, containsNull1 || containsNull2))
case _ => None
})
}
- def findWiderTypeWithoutStringPromotion(
- types: Seq[DataType], caseSensitive: Boolean): Option[DataType] = {
+ def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = {
types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
- case Some(d) => findWiderTypeWithoutStringPromotionForTwo(d, c, caseSensitive)
+ case Some(d) => findWiderTypeWithoutStringPromotionForTwo(d, c)
case None => None
})
}
@@ -296,32 +279,29 @@ object TypeCoercion {
*
* This rule is only applied to Union/Except/Intersect
*/
- case class WidenSetOperationTypes(conf: SQLConf) extends Rule[LogicalPlan] {
+ object WidenSetOperationTypes extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case s @ SetOperation(left, right) if s.childrenResolved &&
left.output.length == right.output.length && !s.resolved =>
- val newChildren: Seq[LogicalPlan] =
- buildNewChildrenWithWiderTypes(left :: right :: Nil, conf.caseSensitiveAnalysis)
+ val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
assert(newChildren.length == 2)
s.makeCopy(Array(newChildren.head, newChildren.last))
case s: Union if s.childrenResolved &&
s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved =>
- val newChildren: Seq[LogicalPlan] =
- buildNewChildrenWithWiderTypes(s.children, conf.caseSensitiveAnalysis)
+ val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(s.children)
s.makeCopy(Array(newChildren))
}
/** Build new children with the widest types for each attribute among all the children */
- private def buildNewChildrenWithWiderTypes(
- children: Seq[LogicalPlan], caseSensitive: Boolean): Seq[LogicalPlan] = {
+ private def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan]): Seq[LogicalPlan] = {
require(children.forall(_.output.length == children.head.output.length))
// Get a sequence of data types, each of which is the widest type of this specific attribute
// in all the children
val targetTypes: Seq[DataType] =
- getWidestTypes(children, attrIndex = 0, mutable.Queue[DataType](), caseSensitive)
+ getWidestTypes(children, attrIndex = 0, mutable.Queue[DataType]())
if (targetTypes.nonEmpty) {
// Add an extra Project if the targetTypes are different from the original types.
@@ -336,19 +316,18 @@ object TypeCoercion {
@tailrec private def getWidestTypes(
children: Seq[LogicalPlan],
attrIndex: Int,
- castedTypes: mutable.Queue[DataType],
- caseSensitive: Boolean): Seq[DataType] = {
+ castedTypes: mutable.Queue[DataType]): Seq[DataType] = {
// Return the result after the widen data types have been found for all the children
if (attrIndex >= children.head.output.length) return castedTypes.toSeq
// For the attrIndex-th attribute, find the widest type
- findWiderCommonType(children.map(_.output(attrIndex).dataType), caseSensitive) match {
+ findWiderCommonType(children.map(_.output(attrIndex).dataType)) match {
// If unable to find an appropriate widen type for this column, return an empty Seq
case None => Seq.empty[DataType]
// Otherwise, record the result in the queue and find the type for the next column
case Some(widenType) =>
castedTypes.enqueue(widenType)
- getWidestTypes(children, attrIndex + 1, castedTypes, caseSensitive)
+ getWidestTypes(children, attrIndex + 1, castedTypes)
}
}
@@ -453,7 +432,7 @@ object TypeCoercion {
val commonTypes = lhs.zip(rhs).flatMap { case (l, r) =>
findCommonTypeForBinaryComparison(l.dataType, r.dataType, conf)
- .orElse(findTightestCommonType(l.dataType, r.dataType, conf.caseSensitiveAnalysis))
+ .orElse(findTightestCommonType(l.dataType, r.dataType))
}
// The number of columns/expressions must match between LHS and RHS of an
@@ -482,7 +461,7 @@ object TypeCoercion {
}
case i @ In(a, b) if b.exists(_.dataType != a.dataType) =>
- findWiderCommonType(i.children.map(_.dataType), conf.caseSensitiveAnalysis) match {
+ findWiderCommonType(i.children.map(_.dataType)) match {
case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType)))
case None => i
}
@@ -536,7 +515,7 @@ object TypeCoercion {
/**
* This ensure that the types for various functions are as expected.
*/
- case class FunctionArgumentConversion(conf: SQLConf) extends TypeCoercionRule {
+ object FunctionArgumentConversion extends TypeCoercionRule {
override protected def coerceTypes(
plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
// Skip nodes who's children have not been resolved yet.
@@ -544,7 +523,7 @@ object TypeCoercion {
case a @ CreateArray(children) if !haveSameType(children) =>
val types = children.map(_.dataType)
- findWiderCommonType(types, conf.caseSensitiveAnalysis) match {
+ findWiderCommonType(types) match {
case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType)))
case None => a
}
@@ -552,7 +531,7 @@ object TypeCoercion {
case c @ Concat(children) if children.forall(c => ArrayType.acceptsType(c.dataType)) &&
!haveSameType(children) =>
val types = children.map(_.dataType)
- findWiderCommonType(types, conf.caseSensitiveAnalysis) match {
+ findWiderCommonType(types) match {
case Some(finalDataType) => Concat(children.map(Cast(_, finalDataType)))
case None => c
}
@@ -563,7 +542,7 @@ object TypeCoercion {
m.keys
} else {
val types = m.keys.map(_.dataType)
- findWiderCommonType(types, conf.caseSensitiveAnalysis) match {
+ findWiderCommonType(types) match {
case Some(finalDataType) => m.keys.map(Cast(_, finalDataType))
case None => m.keys
}
@@ -573,7 +552,7 @@ object TypeCoercion {
m.values
} else {
val types = m.values.map(_.dataType)
- findWiderCommonType(types, conf.caseSensitiveAnalysis) match {
+ findWiderCommonType(types) match {
case Some(finalDataType) => m.values.map(Cast(_, finalDataType))
case None => m.values
}
@@ -601,7 +580,7 @@ object TypeCoercion {
// compatible with every child column.
case c @ Coalesce(es) if !haveSameType(es) =>
val types = es.map(_.dataType)
- findWiderCommonType(types, conf.caseSensitiveAnalysis) match {
+ findWiderCommonType(types) match {
case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType)))
case None => c
}
@@ -611,14 +590,14 @@ object TypeCoercion {
// string.g
case g @ Greatest(children) if !haveSameType(children) =>
val types = children.map(_.dataType)
- findWiderTypeWithoutStringPromotion(types, conf.caseSensitiveAnalysis) match {
+ findWiderTypeWithoutStringPromotion(types) match {
case Some(finalDataType) => Greatest(children.map(Cast(_, finalDataType)))
case None => g
}
case l @ Least(children) if !haveSameType(children) =>
val types = children.map(_.dataType)
- findWiderTypeWithoutStringPromotion(types, conf.caseSensitiveAnalysis) match {
+ findWiderTypeWithoutStringPromotion(types) match {
case Some(finalDataType) => Least(children.map(Cast(_, finalDataType)))
case None => l
}
@@ -658,11 +637,11 @@ object TypeCoercion {
/**
* Coerces the type of different branches of a CASE WHEN statement to a common type.
*/
- case class CaseWhenCoercion(conf: SQLConf) extends TypeCoercionRule {
+ object CaseWhenCoercion extends TypeCoercionRule {
override protected def coerceTypes(
plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case c: CaseWhen if c.childrenResolved && !c.valueTypesEqual =>
- val maybeCommonType = findWiderCommonType(c.valueTypes, conf.caseSensitiveAnalysis)
+ val maybeCommonType = findWiderCommonType(c.valueTypes)
maybeCommonType.map { commonType =>
var changed = false
val newBranches = c.branches.map { case (condition, value) =>
@@ -689,17 +668,16 @@ object TypeCoercion {
/**
* Coerces the type of different branches of If statement to a common type.
*/
- case class IfCoercion(conf: SQLConf) extends TypeCoercionRule {
+ object IfCoercion extends TypeCoercionRule {
override protected def coerceTypes(
plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case e if !e.childrenResolved => e
// Find tightest common type for If, if the true value and false value have different types.
case i @ If(pred, left, right) if left.dataType != right.dataType =>
- findWiderTypeForTwo(left.dataType, right.dataType, conf.caseSensitiveAnalysis).map {
- widestType =>
- val newLeft = if (left.dataType == widestType) left else Cast(left, widestType)
- val newRight = if (right.dataType == widestType) right else Cast(right, widestType)
- If(pred, newLeft, newRight)
+ findWiderTypeForTwo(left.dataType, right.dataType).map { widestType =>
+ val newLeft = if (left.dataType == widestType) left else Cast(left, widestType)
+ val newRight = if (right.dataType == widestType) right else Cast(right, widestType)
+ If(pred, newLeft, newRight)
}.getOrElse(i) // If there is no applicable conversion, leave expression unchanged.
case If(Literal(null, NullType), left, right) =>
If(Literal.create(null, BooleanType), left, right)
@@ -798,11 +776,12 @@ object TypeCoercion {
/**
* Casts types according to the expected input types for [[Expression]]s.
*/
- case class ImplicitTypeCasts(conf: SQLConf) extends TypeCoercionRule {
+ class ImplicitTypeCasts(conf: SQLConf) extends TypeCoercionRule {
private def rejectTzInString = conf.getConf(SQLConf.REJECT_TIMEZONE_IN_STRING)
- override def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ override protected def coerceTypes(
+ plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e
@@ -825,18 +804,17 @@ object TypeCoercion {
}
case b @ BinaryOperator(left, right) if left.dataType != right.dataType =>
- findTightestCommonType(left.dataType, right.dataType, conf.caseSensitiveAnalysis).map {
- commonType =>
- if (b.inputType.acceptsType(commonType)) {
- // If the expression accepts the tightest common type, cast to that.
- val newLeft = if (left.dataType == commonType) left else Cast(left, commonType)
- val newRight = if (right.dataType == commonType) right else Cast(right, commonType)
- b.withNewChildren(Seq(newLeft, newRight))
- } else {
- // Otherwise, don't do anything with the expression.
- b
- }
- }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged.
+ findTightestCommonType(left.dataType, right.dataType).map { commonType =>
+ if (b.inputType.acceptsType(commonType)) {
+ // If the expression accepts the tightest common type, cast to that.
+ val newLeft = if (left.dataType == commonType) left else Cast(left, commonType)
+ val newRight = if (right.dataType == commonType) right else Cast(right, commonType)
+ b.withNewChildren(Seq(newLeft, newRight))
+ } else {
+ // Otherwise, don't do anything with the expression.
+ b
+ }
+ }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged.
case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty =>
val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) =>
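The net effect of these widening rules is easiest to observe from the user side; a quick check (assumes an active `SparkSession` named `spark`; not part of the patch):

```python
# WidenSetOperationTypes inserts casts so that unioning a bigint column with a
# double column produces a double column.
ints = spark.createDataFrame([(1,)], ["v"])        # v: bigint
doubles = spark.createDataFrame([(1.5,)], ["v"])   # v: double
ints.union(doubles).printSchema()                  # v is widened to double
```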
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index d3d6c636c4ba8..2bed41672fe33 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.OutputMode
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala
index 7731336d247db..354a3fa0602a9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala
@@ -41,6 +41,11 @@ package object analysis {
def failAnalysis(msg: String): Nothing = {
throw new AnalysisException(msg, t.origin.line, t.origin.startPosition)
}
+
+ /** Fails the analysis at the point where a specific tree node was parsed. */
+ def failAnalysis(msg: String, cause: Throwable): Nothing = {
+ throw new AnalysisException(msg, t.origin.line, t.origin.startPosition, cause = Some(cause))
+ }
}
/** Catches any AnalysisExceptions thrown by `f` and attaches `t`'s position if any. */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 4cc84b27d9eb0..df3ab05e02c76 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -21,6 +21,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._
/**
@@ -56,13 +57,13 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
if (nullable) {
ev.copy(code =
- s"""
+ code"""
|boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
|$javaType ${ev.value} = ${ev.isNull} ?
| ${CodeGenerator.defaultValue(dataType)} : ($value);
""".stripMargin)
} else {
- ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = FalseLiteral)
+ ev.copy(code = code"$javaType ${ev.value} = $value;", isNull = FalseLiteral)
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 12330bfa55ab9..699ea53b5df0f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -23,6 +23,7 @@ import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@@ -623,8 +624,13 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val eval = child.genCode(ctx)
val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx)
- ev.copy(code = eval.code +
- castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast))
+
+ ev.copy(code =
+ code"""
+ ${eval.code}
+ // This comment is added to manually track references of ${eval.value} and ${eval.isNull}
+ ${castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast)}
+ """)
}
// The function arguments are: `input`, `result` and `resultIsNull`. We don't need `inputIsNull`
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 97dff6ae88299..9b9fa41a47d0f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -22,6 +22,7 @@ import java.util.Locale
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -108,9 +109,9 @@ abstract class Expression extends TreeNode[Expression] {
JavaCode.isNullVariable(isNull),
JavaCode.variable(value, dataType)))
reduceCodeSize(ctx, eval)
- if (eval.code.nonEmpty) {
+ if (eval.code.toString.nonEmpty) {
// Add `this` in the comment.
- eval.copy(code = s"${ctx.registerComment(this.toString)}\n" + eval.code.trim)
+ eval.copy(code = ctx.registerComment(this.toString) + eval.code)
} else {
eval
}
@@ -119,7 +120,7 @@ abstract class Expression extends TreeNode[Expression] {
private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = {
// TODO: support whole stage codegen too
- if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) {
+ if (eval.code.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) {
val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) {
val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull")
val localIsNull = eval.isNull
@@ -136,14 +137,14 @@ abstract class Expression extends TreeNode[Expression] {
val funcFullName = ctx.addNewFunction(funcName,
s"""
|private $javaType $funcName(InternalRow ${ctx.INPUT_ROW}) {
- | ${eval.code.trim}
+ | ${eval.code}
| $setIsNull
| return ${eval.value};
|}
""".stripMargin)
eval.value = JavaCode.variable(newValue, dataType)
- eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});"
+ eval.code = code"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});"
}
}
@@ -437,15 +438,14 @@ abstract class UnaryExpression extends Expression {
if (nullable) {
val nullSafeEval = ctx.nullSafeExec(child.nullable, childGen.isNull)(resultCode)
- ev.copy(code = s"""
+ ev.copy(code = code"""
${childGen.code}
boolean ${ev.isNull} = ${childGen.isNull};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$nullSafeEval
""")
} else {
- ev.copy(code = s"""
- boolean ${ev.isNull} = false;
+ ev.copy(code = code"""
${childGen.code}
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$resultCode""", isNull = FalseLiteral)
@@ -537,14 +537,13 @@ abstract class BinaryExpression extends Expression {
}
}
- ev.copy(code = s"""
+ ev.copy(code = code"""
boolean ${ev.isNull} = true;
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$nullSafeEval
""")
} else {
- ev.copy(code = s"""
- boolean ${ev.isNull} = false;
+ ev.copy(code = code"""
${leftGen.code}
${rightGen.code}
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -681,13 +680,12 @@ abstract class TernaryExpression extends Expression {
}
}
- ev.copy(code = s"""
+ ev.copy(code = code"""
boolean ${ev.isNull} = true;
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$nullSafeEval""")
} else {
- ev.copy(code = s"""
- boolean ${ev.isNull} = false;
+ ev.copy(code = code"""
${leftGen.code}
${midGen.code}
${rightGen.code}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
index 9f0779642271d..f1da592a76845 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{DataType, LongType}
/**
@@ -72,7 +73,7 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Stateful {
ctx.addPartitionInitializationStatement(s"$countTerm = 0L;")
ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;")
- ev.copy(code = s"""
+ ev.copy(code = code"""
final ${CodeGenerator.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm;
$countTerm++;""", isNull = FalseLiteral)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index e869258469a97..3e7ca88249737 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.DataType
/**
@@ -1030,7 +1031,7 @@ case class ScalaUDF(
""".stripMargin
ev.copy(code =
- s"""
+ code"""
|$evalCode
|${initArgs.mkString("\n")}
|$callFunc
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
index ff7c98f714905..2ce9d072c71c9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.unsafe.sort.PrefixComparators._
@@ -181,7 +182,7 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression {
}
ev.copy(code = childCode.code +
- s"""
+ code"""
|long ${ev.value} = 0L;
|boolean ${ev.isNull} = ${childCode.isNull};
|if (!${childCode.isNull}) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
index 787bcaf5e81de..9856b37e53fbc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{DataType, IntegerType}
/**
@@ -46,7 +47,7 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic {
val idTerm = "partitionId"
ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_INT, idTerm)
ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;")
- ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;",
+ ev.copy(code = code"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;",
isNull = FalseLiteral)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
index 6c4a3601c1730..84e38a8b2711e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
@@ -164,7 +165,7 @@ case class PreciseTimestampConversion(
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val eval = child.genCode(ctx)
ev.copy(code = eval.code +
- s"""boolean ${ev.isNull} = ${eval.isNull};
+ code"""boolean ${ev.isNull} = ${eval.isNull};
|${CodeGenerator.javaType(dataType)} ${ev.value} = ${eval.value};
""".stripMargin)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index d4e322d23b95b..fe91e520169b4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
@@ -220,30 +221,12 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti
protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2)
}
-// scalastyle:off line.size.limit
-@ExpressionDescription(
- usage = "expr1 _FUNC_ expr2 - Returns `expr1`/`expr2`. It always performs floating point division.",
- examples = """
- Examples:
- > SELECT 3 _FUNC_ 2;
- 1.5
- > SELECT 2L _FUNC_ 2L;
- 1.0
- """)
-// scalastyle:on line.size.limit
-case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
-
- override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType)
+// Common base trait for Divide and Remainder, since these two classes are almost identical
+trait DivModLike extends BinaryArithmetic {
- override def symbol: String = "/"
- override def decimalMethod: String = "$div"
override def nullable: Boolean = true
- private lazy val div: (Any, Any) => Any = dataType match {
- case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div
- }
-
- override def eval(input: InternalRow): Any = {
+ final override def eval(input: InternalRow): Any = {
val input2 = right.eval(input)
if (input2 == null || input2 == 0) {
null
@@ -252,13 +235,15 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
if (input1 == null) {
null
} else {
- div(input1, input2)
+ evalOperation(input1, input2)
}
}
}
+ def evalOperation(left: Any, right: Any): Any
+
/**
- * Special case handling due to division by 0 => null.
+ * Special case handling due to division/remainder by 0 => null.
*/
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val eval1 = left.genCode(ctx)
@@ -269,13 +254,13 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
s"${eval2.value} == 0"
}
val javaType = CodeGenerator.javaType(dataType)
- val divide = if (dataType.isInstanceOf[DecimalType]) {
+ val operation = if (dataType.isInstanceOf[DecimalType]) {
s"${eval1.value}.$decimalMethod(${eval2.value})"
} else {
s"($javaType)(${eval1.value} $symbol ${eval2.value})"
}
if (!left.nullable && !right.nullable) {
- ev.copy(code = s"""
+ ev.copy(code = code"""
${eval2.code}
boolean ${ev.isNull} = false;
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -283,10 +268,10 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
${ev.isNull} = true;
} else {
${eval1.code}
- ${ev.value} = $divide;
+ ${ev.value} = $operation;
}""")
} else {
- ev.copy(code = s"""
+ ev.copy(code = code"""
${eval2.code}
boolean ${ev.isNull} = false;
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -297,13 +282,38 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
if (${eval1.isNull}) {
${ev.isNull} = true;
} else {
- ${ev.value} = $divide;
+ ${ev.value} = $operation;
}
}""")
}
}
}
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "expr1 _FUNC_ expr2 - Returns `expr1`/`expr2`. It always performs floating point division.",
+ examples = """
+ Examples:
+ > SELECT 3 _FUNC_ 2;
+ 1.5
+ > SELECT 2L _FUNC_ 2L;
+ 1.0
+ """)
+// scalastyle:on line.size.limit
+case class Divide(left: Expression, right: Expression) extends DivModLike {
+
+ override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType)
+
+ override def symbol: String = "/"
+ override def decimalMethod: String = "$div"
+
+ private lazy val div: (Any, Any) => Any = dataType match {
+ case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div
+ }
+
+ override def evalOperation(left: Any, right: Any): Any = div(left, right)
+}
+
@ExpressionDescription(
usage = "expr1 _FUNC_ expr2 - Returns the remainder after `expr1`/`expr2`.",
examples = """
@@ -313,82 +323,30 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
> SELECT MOD(2, 1.8);
0.2
""")
-case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic {
+case class Remainder(left: Expression, right: Expression) extends DivModLike {
override def inputType: AbstractDataType = NumericType
override def symbol: String = "%"
override def decimalMethod: String = "remainder"
- override def nullable: Boolean = true
- private lazy val integral = dataType match {
- case i: IntegralType => i.integral.asInstanceOf[Integral[Any]]
- case i: FractionalType => i.asIntegral.asInstanceOf[Integral[Any]]
+ private lazy val mod: (Any, Any) => Any = dataType match {
+ // special cases to make float/double primitive types faster
+ case DoubleType =>
+ (left, right) => left.asInstanceOf[Double] % right.asInstanceOf[Double]
+ case FloatType =>
+ (left, right) => left.asInstanceOf[Float] % right.asInstanceOf[Float]
+
+ // catch-all cases
+ case i: IntegralType =>
+ val integral = i.integral.asInstanceOf[Integral[Any]]
+ (left, right) => integral.rem(left, right)
+ case i: FractionalType => // should only be DecimalType for now
+ val integral = i.asIntegral.asInstanceOf[Integral[Any]]
+ (left, right) => integral.rem(left, right)
}
- override def eval(input: InternalRow): Any = {
- val input2 = right.eval(input)
- if (input2 == null || input2 == 0) {
- null
- } else {
- val input1 = left.eval(input)
- if (input1 == null) {
- null
- } else {
- input1 match {
- case d: Double => d % input2.asInstanceOf[java.lang.Double]
- case f: Float => f % input2.asInstanceOf[java.lang.Float]
- case _ => integral.rem(input1, input2)
- }
- }
- }
- }
-
- /**
- * Special case handling for x % 0 ==> null.
- */
- override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val eval1 = left.genCode(ctx)
- val eval2 = right.genCode(ctx)
- val isZero = if (dataType.isInstanceOf[DecimalType]) {
- s"${eval2.value}.isZero()"
- } else {
- s"${eval2.value} == 0"
- }
- val javaType = CodeGenerator.javaType(dataType)
- val remainder = if (dataType.isInstanceOf[DecimalType]) {
- s"${eval1.value}.$decimalMethod(${eval2.value})"
- } else {
- s"($javaType)(${eval1.value} $symbol ${eval2.value})"
- }
- if (!left.nullable && !right.nullable) {
- ev.copy(code = s"""
- ${eval2.code}
- boolean ${ev.isNull} = false;
- $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
- if ($isZero) {
- ${ev.isNull} = true;
- } else {
- ${eval1.code}
- ${ev.value} = $remainder;
- }""")
- } else {
- ev.copy(code = s"""
- ${eval2.code}
- boolean ${ev.isNull} = false;
- $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
- if (${eval2.isNull} || $isZero) {
- ${ev.isNull} = true;
- } else {
- ${eval1.code}
- if (${eval1.isNull}) {
- ${ev.isNull} = true;
- } else {
- ${ev.value} = $remainder;
- }
- }""")
- }
- }
+ override def evalOperation(left: Any, right: Any): Any = mod(left, right)
}
@ExpressionDescription(
@@ -479,7 +437,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
}
if (!left.nullable && !right.nullable) {
- ev.copy(code = s"""
+ ev.copy(code = code"""
${eval2.code}
boolean ${ev.isNull} = false;
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -490,7 +448,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
$result
}""")
} else {
- ev.copy(code = s"""
+ ev.copy(code = code"""
${eval2.code}
boolean ${ev.isNull} = false;
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -612,7 +570,7 @@ case class Least(children: Seq[Expression]) extends Expression {
""".stripMargin,
foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
ev.copy(code =
- s"""
+ code"""
|${ev.isNull} = true;
|$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|$codes
@@ -687,7 +645,7 @@ case class Greatest(children: Seq[Expression]) extends Expression {
""".stripMargin,
foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
ev.copy(code =
- s"""
+ code"""
|${ev.isNull} = true;
|$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|$codes
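
Before moving on from arithmetic.scala: the DivModLike refactor above keeps the existing divide/remainder semantics, where the divisor is evaluated first and a null or zero divisor yields null without evaluating the dividend. A minimal plain-Scala sketch of that shared evaluation order (illustrative names only, not Spark's Expression API):

```scala
// Sketch of the evaluation order DivModLike factors out of Divide and Remainder.
// A null or zero divisor short-circuits to null before the dividend is evaluated.
def divModEval(evalLeft: () => Any, evalRight: () => Any)(op: (Any, Any) => Any): Any = {
  val right = evalRight()
  if (right == null || right == 0) {
    null
  } else {
    val left = evalLeft()
    if (left == null) null else op(left, right)
  }
}

// Integer remainder with a zero divisor yields null instead of throwing:
val r = divModEval(() => 7, () => 0) { (a, b) =>
  a.asInstanceOf[Int] % b.asInstanceOf[Int]
}
assert(r == null)
```
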
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 4dda525294259..66315e5906253 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -38,6 +38,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.metrics.source.CodegenMetrics
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -57,19 +58,19 @@ import org.apache.spark.util.{ParentClassLoader, Utils}
* @param value A term for a (possibly primitive) value of the result of the evaluation. Not
* valid if `isNull` is set to `true`.
*/
-case class ExprCode(var code: String, var isNull: ExprValue, var value: ExprValue)
+case class ExprCode(var code: Block, var isNull: ExprValue, var value: ExprValue)
object ExprCode {
def apply(isNull: ExprValue, value: ExprValue): ExprCode = {
- ExprCode(code = "", isNull, value)
+ ExprCode(code = EmptyBlock, isNull, value)
}
def forNullValue(dataType: DataType): ExprCode = {
- ExprCode(code = "", isNull = TrueLiteral, JavaCode.defaultLiteral(dataType))
+ ExprCode(code = EmptyBlock, isNull = TrueLiteral, JavaCode.defaultLiteral(dataType))
}
def forNonNullValue(value: ExprValue): ExprCode = {
- ExprCode(code = "", isNull = FalseLiteral, value = value)
+ ExprCode(code = EmptyBlock, isNull = FalseLiteral, value = value)
}
}
@@ -330,9 +331,9 @@ class CodegenContext {
def addBufferedState(dataType: DataType, variableName: String, initCode: String): ExprCode = {
val value = addMutableState(javaType(dataType), variableName)
val code = dataType match {
- case StringType => s"$value = $initCode.clone();"
- case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();"
- case _ => s"$value = $initCode;"
+ case StringType => code"$value = $initCode.clone();"
+ case _: StructType | _: ArrayType | _: MapType => code"$value = $initCode.copy();"
+ case _ => code"$value = $initCode;"
}
ExprCode(code, FalseLiteral, JavaCode.global(value, dataType))
}
@@ -764,6 +765,40 @@ class CodegenContext {
""".stripMargin
}
+ /**
+ * Generates code creating an [[UnsafeArrayData]]. The generated code executes
+ * a provided fallback when the size of the backing array would exceed the array size limit.
+ * @param arrayName the name of the array to create
+ * @param numElements a piece of code representing the number of elements the array should contain
+ * @param elementSize the size of an element in bytes
+ * @param bodyCode a function that generates the code filling the [[UnsafeArrayData]],
+ * taking the name of the backing byte array as a parameter
+ * @param fallbackCode a piece of code executed when the array size limit is exceeded
+ */
+ def createUnsafeArrayWithFallback(
+ arrayName: String,
+ numElements: String,
+ elementSize: Int,
+ bodyCode: String => String,
+ fallbackCode: String): String = {
+ val arraySize = freshName("size")
+ val arrayBytes = freshName("arrayBytes")
+ s"""
+ |final long $arraySize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
+ | $numElements,
+ | $elementSize);
+ |if ($arraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
+ | $fallbackCode
+ |} else {
+ | final byte[] $arrayBytes = new byte[(int)$arraySize];
+ | UnsafeArrayData $arrayName = new UnsafeArrayData();
+ | Platform.putLong($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, $numElements);
+ | $arrayName.pointTo($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, (int)$arraySize);
+ | ${bodyCode(arrayBytes)}
+ |}
+ """.stripMargin
+ }
+
/**
* Generates code to do null safe execution, i.e. only execute the code when the input is not
* null by adding null check if necessary.
@@ -1022,7 +1057,7 @@ class CodegenContext {
val eval = expr.genCode(this)
val state = SubExprEliminationState(eval.isNull, eval.value)
e.foreach(localSubExprEliminationExprs.put(_, state))
- eval.code.trim
+ eval.code.toString
}
SubExprCodes(codes, localSubExprEliminationExprs.toMap)
}
@@ -1050,7 +1085,7 @@ class CodegenContext {
val fn =
s"""
|private void $fnName(InternalRow $INPUT_ROW) {
- | ${eval.code.trim}
+ | ${eval.code}
| $isNull = ${eval.isNull};
| $value = ${eval.value};
|}
@@ -1107,7 +1142,7 @@ class CodegenContext {
def registerComment(
text: => String,
placeholderId: String = "",
- force: Boolean = false): String = {
+ force: Boolean = false): Block = {
// By default, disable comments in generated code because computing the comments themselves can
// be extremely expensive in certain cases, such as deeply-nested expressions which operate over
// inputs with wide schemas. For more details on the performance issues that motivated this
@@ -1126,9 +1161,9 @@ class CodegenContext {
s"// $text"
}
placeHolderToComments += (name -> comment)
- s"/*$name*/"
+ code"/*$name*/"
} else {
- ""
+ EmptyBlock
}
}
}
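
A hypothetical caller of the new createUnsafeArrayWithFallback helper, assuming a CodegenContext named ctx is in scope; apart from the helper's signature, which comes from this patch, every name in the snippet is made up for illustration:

```scala
// Fill an unsafe array of longs when the backing byte array fits under the
// size limit, otherwise run the fallback code. "numElements", "values",
// "result" and "toGenericArrayData" are assumed to exist in the surrounding
// generated Java; they are placeholders for this sketch.
val fillUnsafe: String => String = arrayBytes =>
  // arrayBytes (the name of the backing byte array) is available if the body needs it
  s"""
     |for (int i = 0; i < numElements; i++) {
     |  myUnsafeArray.setLong(i, values[i]);
     |}
     |result = myUnsafeArray;
   """.stripMargin

val generatedJava: String = ctx.createUnsafeArrayWithFallback(
  arrayName = "myUnsafeArray",
  numElements = "numElements",   // a Java expression evaluated at runtime
  elementSize = 8,               // size of a long element in bytes
  bodyCode = fillUnsafe,
  fallbackCode = "result = toGenericArrayData(values);")
```
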
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
index a91989e129664..3f4704d287cbd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions.codegen
import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, Nondeterministic}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
/**
* A trait that can be used to provide a fallback mode for expression code generation.
@@ -46,7 +47,7 @@ trait CodegenFallback extends Expression {
val placeHolder = ctx.registerComment(this.toString)
val javaType = CodeGenerator.javaType(this.dataType)
if (nullable) {
- ev.copy(code = s"""
+ ev.copy(code = code"""
$placeHolder
Object $objectTerm = ((Expression) references[$idx]).eval($input);
boolean ${ev.isNull} = $objectTerm == null;
@@ -55,7 +56,7 @@ trait CodegenFallback extends Expression {
${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm;
}""")
} else {
- ev.copy(code = s"""
+ ev.copy(code = code"""
$placeHolder
Object $objectTerm = ((Expression) references[$idx]).eval($input);
$javaType ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm;
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
index 01c350e9dbf69..39778661d1c48 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
@@ -22,6 +22,7 @@ import scala.annotation.tailrec
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.types._
@@ -71,7 +72,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
arguments = Seq("InternalRow" -> tmpInput, "Object[]" -> values)
)
val code =
- s"""
+ code"""
|final InternalRow $tmpInput = $input;
|final Object[] $values = new Object[${schema.length}];
|$allFields
@@ -97,7 +98,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
ctx,
JavaCode.expression(CodeGenerator.getValue(tmpInput, elementType, index), elementType),
elementType)
- val code = s"""
+ val code = code"""
final ArrayData $tmpInput = $input;
final int $numElements = $tmpInput.numElements();
final Object[] $values = new Object[$numElements];
@@ -124,7 +125,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
val keyConverter = createCodeForArray(ctx, s"$tmpInput.keyArray()", keyType)
val valueConverter = createCodeForArray(ctx, s"$tmpInput.valueArray()", valueType)
- val code = s"""
+ val code = code"""
final MapData $tmpInput = $input;
${keyConverter.code}
${valueConverter.code}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 01b4d6c4529bd..8f2a5a0dce943 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions.codegen
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._
/**
@@ -286,7 +287,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
ctx, ctx.INPUT_ROW, exprEvals, exprTypes, rowWriter, isTopLevel = true)
val code =
- s"""
+ code"""
|$rowWriter.reset();
|$evalSubexpr
|$writeExpressions
@@ -343,7 +344,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
| }
|
| public UnsafeRow apply(InternalRow ${ctx.INPUT_ROW}) {
- | ${eval.code.trim}
+ | ${eval.code}
| return ${eval.value};
| }
|
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala
index 74ff018488863..250ce48d059e0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen
import java.lang.{Boolean => JBool}
+import scala.collection.mutable.ArrayBuffer
import scala.language.{existentials, implicitConversions}
import org.apache.spark.sql.types.{BooleanType, DataType}
@@ -114,6 +115,147 @@ object JavaCode {
}
}
+/**
+ * A trait representing a block of java code.
+ */
+trait Block extends JavaCode {
+
+ // The expressions to be evaluated inside this block.
+ def exprValues: Set[ExprValue]
+
+ // Returns java code string for this code block.
+ override def toString: String = _marginChar match {
+ case Some(c) => code.stripMargin(c).trim
+ case _ => code.trim
+ }
+
+ def length: Int = toString.length
+
+ def nonEmpty: Boolean = toString.nonEmpty
+
+ // The leading prefix that should be stripped from each line.
+ // By default we strip blanks or control characters followed by '|' from the line.
+ var _marginChar: Option[Char] = Some('|')
+
+ def stripMargin(c: Char): this.type = {
+ _marginChar = Some(c)
+ this
+ }
+
+ def stripMargin: this.type = {
+ _marginChar = Some('|')
+ this
+ }
+
+ // Concatenates this block with another block.
+ def + (other: Block): Block
+}
+
+object Block {
+
+ val CODE_BLOCK_BUFFER_LENGTH: Int = 512
+
+ implicit def blocksToBlock(blocks: Seq[Block]): Block = Blocks(blocks)
+
+ implicit class BlockHelper(val sc: StringContext) extends AnyVal {
+ def code(args: Any*): Block = {
+ sc.checkLengths(args)
+ if (sc.parts.length == 0) {
+ EmptyBlock
+ } else {
+ args.foreach {
+ case _: ExprValue =>
+ case _: Int | _: Long | _: Float | _: Double | _: String =>
+ case _: Block =>
+ case other => throw new IllegalArgumentException(
+ s"Can not interpolate ${other.getClass.getName} into code block.")
+ }
+
+ val (codeParts, blockInputs) = foldLiteralArgs(sc.parts, args)
+ CodeBlock(codeParts, blockInputs)
+ }
+ }
+ }
+
+ // Eagerly folds the literal args into the code parts.
+ private def foldLiteralArgs(parts: Seq[String], args: Seq[Any]): (Seq[String], Seq[JavaCode]) = {
+ val codeParts = ArrayBuffer.empty[String]
+ val blockInputs = ArrayBuffer.empty[JavaCode]
+
+ val strings = parts.iterator
+ val inputs = args.iterator
+ val buf = new StringBuilder(Block.CODE_BLOCK_BUFFER_LENGTH)
+
+ buf.append(strings.next)
+ while (strings.hasNext) {
+ val input = inputs.next
+ input match {
+ case _: ExprValue | _: Block =>
+ codeParts += buf.toString
+ buf.clear
+ blockInputs += input.asInstanceOf[JavaCode]
+ case _ =>
+ buf.append(input)
+ }
+ buf.append(strings.next)
+ }
+ if (buf.nonEmpty) {
+ codeParts += buf.toString
+ }
+
+ (codeParts.toSeq, blockInputs.toSeq)
+ }
+}
+
+/**
+ * A block of java code, consisting of a sequence of code parts and the inputs to this block.
+ * The actual java code is generated by embedding the inputs into the code parts.
+ */
+case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends Block {
+ override lazy val exprValues: Set[ExprValue] = {
+ blockInputs.flatMap {
+ case b: Block => b.exprValues
+ case e: ExprValue => Set(e)
+ }.toSet
+ }
+
+ override lazy val code: String = {
+ val strings = codeParts.iterator
+ val inputs = blockInputs.iterator
+ val buf = new StringBuilder(Block.CODE_BLOCK_BUFFER_LENGTH)
+ buf.append(StringContext.treatEscapes(strings.next))
+ while (strings.hasNext) {
+ buf.append(inputs.next)
+ buf.append(StringContext.treatEscapes(strings.next))
+ }
+ buf.toString
+ }
+
+ override def + (other: Block): Block = other match {
+ case c: CodeBlock => Blocks(Seq(this, c))
+ case b: Blocks => Blocks(Seq(this) ++ b.blocks)
+ case EmptyBlock => this
+ }
+}
+
+case class Blocks(blocks: Seq[Block]) extends Block {
+ override lazy val exprValues: Set[ExprValue] = blocks.flatMap(_.exprValues).toSet
+ override lazy val code: String = blocks.map(_.toString).mkString("\n")
+
+ override def + (other: Block): Block = other match {
+ case c: CodeBlock => Blocks(blocks :+ c)
+ case b: Blocks => Blocks(blocks ++ b.blocks)
+ case EmptyBlock => this
+ }
+}
+
+object EmptyBlock extends Block with Serializable {
+ override val code: String = ""
+ override val exprValues: Set[ExprValue] = Set.empty
+
+ override def + (other: Block): Block = other
+}
+
/**
* A typed java fragment that must be a valid java expression.
*/
@@ -123,10 +265,9 @@ trait ExprValue extends JavaCode {
}
object ExprValue {
- implicit def exprValueToString(exprValue: ExprValue): String = exprValue.toString
+ implicit def exprValueToString(exprValue: ExprValue): String = exprValue.code
}
-
/**
* A java expression fragment.
*/
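
Before the next file, a minimal usage sketch of the code interpolator and Block API added in javaCode.scala above; the ExprValue names are made up for illustration, while the calls themselves are the ones defined in this file:

```scala
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.IntegerType

// Interpolated ExprValues are tracked by the resulting Block; plain
// Int/Long/Float/Double/String arguments are folded into the code text.
val isNull = JavaCode.isNullVariable("isNull_0")
val value = JavaCode.variable("value_0", IntegerType)

val block: Block =
  code"""
     |boolean $isNull = false;
     |int $value = 42;
   """.stripMargin

println(block)            // the generated Java, margins stripped and trimmed
println(block.exprValues) // the ExprValues interpolated into the block
```
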
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 12b9ab2b272ab..03b3b21a16617 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -18,15 +18,53 @@ package org.apache.spark.sql.catalyst.expressions
import java.util.Comparator
+import scala.collection.mutable
+
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
+/**
+ * Base trait for [[BinaryExpression]]s with two arrays of the same element type and implicit
+ * casting.
+ */
+trait BinaryArrayExpressionWithImplicitCast extends BinaryExpression
+ with ImplicitCastInputTypes {
+
+ @transient protected lazy val elementType: DataType =
+ inputTypes.head.asInstanceOf[ArrayType].elementType
+
+ override def inputTypes: Seq[AbstractDataType] = {
+ (left.dataType, right.dataType) match {
+ case (ArrayType(e1, hasNull1), ArrayType(e2, hasNull2)) =>
+ TypeCoercion.findTightestCommonType(e1, e2) match {
+ case Some(dt) => Seq(ArrayType(dt, hasNull1), ArrayType(dt, hasNull2))
+ case _ => Seq.empty
+ }
+ case _ => Seq.empty
+ }
+ }
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ (left.dataType, right.dataType) match {
+ case (ArrayType(e1, _), ArrayType(e2, _)) if e1.sameType(e2) =>
+ TypeCheckResult.TypeCheckSuccess
+ case _ => TypeCheckResult.TypeCheckFailure(s"input to function $prettyName should have " +
+ s"been two ${ArrayType.simpleString}s with same element type, but it's " +
+ s"[${left.dataType.simpleString}, ${right.dataType.simpleString}]")
+ }
+ }
+}
+
+
/**
* Given an array or map, returns its size. Returns -1 if null.
*/
@@ -54,7 +92,7 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val childGen = child.genCode(ctx)
- ev.copy(code = s"""
+ ev.copy(code = code"""
boolean ${ev.isNull} = false;
${childGen.code}
${CodeGenerator.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 :
@@ -118,6 +156,158 @@ case class MapValues(child: Expression)
override def prettyName: String = "map_values"
}
+/**
+ * Returns an unordered array of all entries in the given map.
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(map) - Returns an unordered array of all entries in the given map.",
+ examples = """
+ Examples:
+ > SELECT _FUNC_(map(1, 'a', 2, 'b'));
+ [(1,"a"),(2,"b")]
+ """,
+ since = "2.4.0")
+case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(MapType)
+
+ lazy val childDataType: MapType = child.dataType.asInstanceOf[MapType]
+
+ override def dataType: DataType = {
+ ArrayType(
+ StructType(
+ StructField("key", childDataType.keyType, false) ::
+ StructField("value", childDataType.valueType, childDataType.valueContainsNull) ::
+ Nil),
+ false)
+ }
+
+ override protected def nullSafeEval(input: Any): Any = {
+ val childMap = input.asInstanceOf[MapData]
+ val keys = childMap.keyArray()
+ val values = childMap.valueArray()
+ val length = childMap.numElements()
+ val resultData = new Array[AnyRef](length)
+ var i = 0
+ while (i < length) {
+ val key = keys.get(i, childDataType.keyType)
+ val value = values.get(i, childDataType.valueType)
+ val row = new GenericInternalRow(Array[Any](key, value))
+ resultData.update(i, row)
+ i += 1
+ }
+ new GenericArrayData(resultData)
+ }
+
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ nullSafeCodeGen(ctx, ev, c => {
+ val numElements = ctx.freshName("numElements")
+ val keys = ctx.freshName("keys")
+ val values = ctx.freshName("values")
+ val isKeyPrimitive = CodeGenerator.isPrimitiveType(childDataType.keyType)
+ val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType)
+ val code = if (isKeyPrimitive && isValuePrimitive) {
+ genCodeForPrimitiveElements(ctx, keys, values, ev.value, numElements)
+ } else {
+ genCodeForAnyElements(ctx, keys, values, ev.value, numElements)
+ }
+ s"""
+ |final int $numElements = $c.numElements();
+ |final ArrayData $keys = $c.keyArray();
+ |final ArrayData $values = $c.valueArray();
+ |$code
+ """.stripMargin
+ })
+ }
+
+ private def getKey(varName: String) = CodeGenerator.getValue(varName, childDataType.keyType, "z")
+
+ private def getValue(varName: String) = {
+ CodeGenerator.getValue(varName, childDataType.valueType, "z")
+ }
+
+ private def genCodeForPrimitiveElements(
+ ctx: CodegenContext,
+ keys: String,
+ values: String,
+ arrayData: String,
+ numElements: String): String = {
+ val unsafeRow = ctx.freshName("unsafeRow")
+ val unsafeArrayData = ctx.freshName("unsafeArrayData")
+ val structsOffset = ctx.freshName("structsOffset")
+ val calculateHeader = "UnsafeArrayData.calculateHeaderPortionInBytes"
+
+ val baseOffset = Platform.BYTE_ARRAY_OFFSET
+ val wordSize = UnsafeRow.WORD_SIZE
+ val structSize = UnsafeRow.calculateBitSetWidthInBytes(2) + wordSize * 2
+ val structSizeAsLong = structSize + "L"
+ val keyTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType)
+ val valueTypeName = CodeGenerator.primitiveTypeName(childDataType.valueType)
+
+ val valueAssignment = s"$unsafeRow.set$valueTypeName(1, ${getValue(values)});"
+ val valueAssignmentChecked = if (childDataType.valueContainsNull) {
+ s"""
+ |if ($values.isNullAt(z)) {
+ | $unsafeRow.setNullAt(1);
+ |} else {
+ | $valueAssignment
+ |}
+ """.stripMargin
+ } else {
+ valueAssignment
+ }
+
+ val assignmentLoop = (byteArray: String) =>
+ s"""
+ |final int $structsOffset = $calculateHeader($numElements) + $numElements * $wordSize;
+ |UnsafeRow $unsafeRow = new UnsafeRow(2);
+ |for (int z = 0; z < $numElements; z++) {
+ | long offset = $structsOffset + z * $structSizeAsLong;
+ | $unsafeArrayData.setLong(z, (offset << 32) + $structSizeAsLong);
+ | $unsafeRow.pointTo($byteArray, $baseOffset + offset, $structSize);
+ | $unsafeRow.set$keyTypeName(0, ${getKey(keys)});
+ | $valueAssignmentChecked
+ |}
+ |$arrayData = $unsafeArrayData;
+ """.stripMargin
+
+ ctx.createUnsafeArrayWithFallback(
+ unsafeArrayData,
+ numElements,
+ structSize + wordSize,
+ assignmentLoop,
+ genCodeForAnyElements(ctx, keys, values, arrayData, numElements))
+ }
+
+ private def genCodeForAnyElements(
+ ctx: CodegenContext,
+ keys: String,
+ values: String,
+ arrayData: String,
+ numElements: String): String = {
+ val genericArrayClass = classOf[GenericArrayData].getName
+ val rowClass = classOf[GenericInternalRow].getName
+ val data = ctx.freshName("internalRowArray")
+
+ val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType)
+ val getValueWithCheck = if (childDataType.valueContainsNull && isValuePrimitive) {
+ s"$values.isNullAt(z) ? null : (Object)${getValue(values)}"
+ } else {
+ getValue(values)
+ }
+
+ s"""
+ |final Object[] $data = new Object[$numElements];
+ |for (int z = 0; z < $numElements; z++) {
+ | $data[z] = new $rowClass(new Object[]{${getKey(keys)}, $getValueWithCheck});
+ |}
+ |$arrayData = new $genericArrayClass($data);
+ """.stripMargin
+ }
+
+ override def prettyName: String = "map_entries"
+}
+
/**
* Common base class for [[SortArray]] and [[ArraySort]].
*/
@@ -468,6 +658,9 @@ case class ArrayContains(left: Expression, right: Expression)
override def dataType: DataType = BooleanType
+ @transient private lazy val ordering: Ordering[Any] =
+ TypeUtils.getInterpretedOrdering(right.dataType)
+
override def inputTypes: Seq[AbstractDataType] = right.dataType match {
case NullType => Seq.empty
case _ => left.dataType match {
@@ -484,7 +677,7 @@ case class ArrayContains(left: Expression, right: Expression)
TypeCheckResult.TypeCheckFailure(
"Arguments must be an array followed by a value of same type as the array members")
} else {
- TypeCheckResult.TypeCheckSuccess
+ TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName")
}
}
@@ -497,7 +690,7 @@ case class ArrayContains(left: Expression, right: Expression)
arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) =>
if (v == null) {
hasNull = true
- } else if (v == value) {
+ } else if (ordering.equiv(v, value)) {
return true
}
)
@@ -529,6 +722,231 @@ case class ArrayContains(left: Expression, right: Expression)
override def prettyName: String = "array_contains"
}
+/**
+ * Checks if the two arrays contain at least one common element.
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(a1, a2) - Returns true if a1 contains at least one non-null element also present in a2. If the arrays have no common element, they are both non-empty, and either of them contains a null element, null is returned; false otherwise.",
+ examples = """
+ Examples:
+ > SELECT _FUNC_(array(1, 2, 3), array(3, 4, 5));
+ true
+ """, since = "2.4.0")
+// scalastyle:on line.size.limit
+case class ArraysOverlap(left: Expression, right: Expression)
+ extends BinaryArrayExpressionWithImplicitCast {
+
+ override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match {
+ case TypeCheckResult.TypeCheckSuccess =>
+ TypeUtils.checkForOrderingExpr(elementType, s"function $prettyName")
+ case failure => failure
+ }
+
+ @transient private lazy val ordering: Ordering[Any] =
+ TypeUtils.getInterpretedOrdering(elementType)
+
+ @transient private lazy val elementTypeSupportEquals = elementType match {
+ case BinaryType => false
+ case _: AtomicType => true
+ case _ => false
+ }
+
+ @transient private lazy val doEvaluation = if (elementTypeSupportEquals) {
+ fastEval _
+ } else {
+ bruteForceEval _
+ }
+
+ override def dataType: DataType = BooleanType
+
+ override def nullable: Boolean = {
+ left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull ||
+ right.dataType.asInstanceOf[ArrayType].containsNull
+ }
+
+ override def nullSafeEval(a1: Any, a2: Any): Any = {
+ doEvaluation(a1.asInstanceOf[ArrayData], a2.asInstanceOf[ArrayData])
+ }
+
+ /**
+ * A fast implementation which puts all the elements from the smaller array in a set
+ * and then performs a lookup on it for each element of the bigger one.
+ * This eval mode works only for data types that properly implement the equals method.
+ */
+ private def fastEval(arr1: ArrayData, arr2: ArrayData): Any = {
+ var hasNull = false
+ val (bigger, smaller) = if (arr1.numElements() > arr2.numElements()) {
+ (arr1, arr2)
+ } else {
+ (arr2, arr1)
+ }
+ if (smaller.numElements() > 0) {
+ val smallestSet = new mutable.HashSet[Any]
+ smaller.foreach(elementType, (_, v) =>
+ if (v == null) {
+ hasNull = true
+ } else {
+ smallestSet += v
+ })
+ bigger.foreach(elementType, (_, v1) =>
+ if (v1 == null) {
+ hasNull = true
+ } else if (smallestSet.contains(v1)) {
+ return true
+ }
+ )
+ }
+ if (hasNull) {
+ null
+ } else {
+ false
+ }
+ }
+
+ /**
+ * A slower evaluation which performs a nested loop and supports all the data types.
+ */
+ private def bruteForceEval(arr1: ArrayData, arr2: ArrayData): Any = {
+ var hasNull = false
+ if (arr1.numElements() > 0 && arr2.numElements() > 0) {
+ arr1.foreach(elementType, (_, v1) =>
+ if (v1 == null) {
+ hasNull = true
+ } else {
+ arr2.foreach(elementType, (_, v2) =>
+ if (v2 == null) {
+ hasNull = true
+ } else if (ordering.equiv(v1, v2)) {
+ return true
+ }
+ )
+ })
+ }
+ if (hasNull) {
+ null
+ } else {
+ false
+ }
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ nullSafeCodeGen(ctx, ev, (a1, a2) => {
+ val smaller = ctx.freshName("smallerArray")
+ val bigger = ctx.freshName("biggerArray")
+ val comparisonCode = if (elementTypeSupportEquals) {
+ fastCodegen(ctx, ev, smaller, bigger)
+ } else {
+ bruteForceCodegen(ctx, ev, smaller, bigger)
+ }
+ s"""
+ |ArrayData $smaller;
+ |ArrayData $bigger;
+ |if ($a1.numElements() > $a2.numElements()) {
+ | $bigger = $a1;
+ | $smaller = $a2;
+ |} else {
+ | $smaller = $a1;
+ | $bigger = $a2;
+ |}
+ |if ($smaller.numElements() > 0) {
+ | $comparisonCode
+ |}
+ """.stripMargin
+ })
+ }
+
+ /**
+ * Code generation for a fast implementation which puts all the elements from the smaller array
+ * in a set and then performs a lookup on it for each element of the bigger one.
+ * It works only for data types that properly implement the equals method.
+ */
+ private def fastCodegen(ctx: CodegenContext, ev: ExprCode, smaller: String, bigger: String): String = {
+ val i = ctx.freshName("i")
+ val getFromSmaller = CodeGenerator.getValue(smaller, elementType, i)
+ val getFromBigger = CodeGenerator.getValue(bigger, elementType, i)
+ val javaElementClass = CodeGenerator.boxedType(elementType)
+ val javaSet = classOf[java.util.HashSet[_]].getName
+ val set = ctx.freshName("set")
+ val addToSetFromSmallerCode = nullSafeElementCodegen(
+ smaller, i, s"$set.add($getFromSmaller);", s"${ev.isNull} = true;")
+ val elementIsInSetCode = nullSafeElementCodegen(
+ bigger,
+ i,
+ s"""
+ |if ($set.contains($getFromBigger)) {
+ | ${ev.isNull} = false;
+ | ${ev.value} = true;
+ | break;
+ |}
+ """.stripMargin,
+ s"${ev.isNull} = true;")
+ s"""
+ |$javaSet<$javaElementClass> $set = new $javaSet<$javaElementClass>();
+ |for (int $i = 0; $i < $smaller.numElements(); $i ++) {
+ | $addToSetFromSmallerCode
+ |}
+ |for (int $i = 0; $i < $bigger.numElements(); $i ++) {
+ | $elementIsInSetCode
+ |}
+ """.stripMargin
+ }
+
+ /**
+ * Code generation for a slower evaluation which performs a nested loop and supports all the data types.
+ */
+ private def bruteForceCodegen(ctx: CodegenContext, ev: ExprCode, smaller: String, bigger: String): String = {
+ val i = ctx.freshName("i")
+ val j = ctx.freshName("j")
+ val getFromSmaller = CodeGenerator.getValue(smaller, elementType, j)
+ val getFromBigger = CodeGenerator.getValue(bigger, elementType, i)
+ val compareValues = nullSafeElementCodegen(
+ smaller,
+ j,
+ s"""
+ |if (${ctx.genEqual(elementType, getFromSmaller, getFromBigger)}) {
+ | ${ev.isNull} = false;
+ | ${ev.value} = true;
+ |}
+ """.stripMargin,
+ s"${ev.isNull} = true;")
+ val isInSmaller = nullSafeElementCodegen(
+ bigger,
+ i,
+ s"""
+ |for (int $j = 0; $j < $smaller.numElements() && !${ev.value}; $j ++) {
+ | $compareValues
+ |}
+ """.stripMargin,
+ s"${ev.isNull} = true;")
+ s"""
+ |for (int $i = 0; $i < $bigger.numElements() && !${ev.value}; $i ++) {
+ | $isInSmaller
+ |}
+ """.stripMargin
+ }
+
+ def nullSafeElementCodegen(
+ arrayVar: String,
+ index: String,
+ code: String,
+ isNullCode: String): String = {
+ if (inputTypes.exists(_.asInstanceOf[ArrayType].containsNull)) {
+ s"""
+ |if ($arrayVar.isNullAt($index)) {
+ | $isNullCode
+ |} else {
+ | $code
+ |}
+ """.stripMargin
+ } else {
+ code
+ }
+ }
+
+ override def prettyName: String = "arrays_overlap"
+}
+
/**
* Slices an array according to the requested start index and length
*/
@@ -760,14 +1178,14 @@ case class ArrayJoin(
}
if (nullable) {
ev.copy(
- s"""
+ code"""
|boolean ${ev.isNull} = true;
|UTF8String ${ev.value} = null;
|$code
""".stripMargin)
} else {
ev.copy(
- s"""
+ code"""
|UTF8String ${ev.value} = null;
|$code
""".stripMargin, FalseLiteral)
@@ -852,11 +1270,11 @@ case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCast
val childGen = child.genCode(ctx)
val javaType = CodeGenerator.javaType(dataType)
val i = ctx.freshName("i")
- val item = ExprCode("",
+ val item = ExprCode(EmptyBlock,
isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"),
value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType))
ev.copy(code =
- s"""
+ code"""
|${childGen.code}
|boolean ${ev.isNull} = true;
|$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -917,11 +1335,11 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast
val childGen = child.genCode(ctx)
val javaType = CodeGenerator.javaType(dataType)
val i = ctx.freshName("i")
- val item = ExprCode("",
+ val item = ExprCode(EmptyBlock,
isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"),
value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType))
ev.copy(code =
- s"""
+ code"""
|${childGen.code}
|boolean ${ev.isNull} = true;
|$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -973,13 +1391,24 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast
case class ArrayPosition(left: Expression, right: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
+ @transient private lazy val ordering: Ordering[Any] =
+ TypeUtils.getInterpretedOrdering(right.dataType)
+
override def dataType: DataType = LongType
override def inputTypes: Seq[AbstractDataType] =
Seq(ArrayType, left.dataType.asInstanceOf[ArrayType].elementType)
+ override def checkInputDataTypes(): TypeCheckResult = {
+ super.checkInputDataTypes() match {
+ case f: TypeCheckResult.TypeCheckFailure => f
+ case TypeCheckResult.TypeCheckSuccess =>
+ TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName")
+ }
+ }
+
override def nullSafeEval(arr: Any, value: Any): Any = {
arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) =>
- if (v == value) {
+ if (v != null && ordering.equiv(v, value)) {
return (i + 1).toLong
}
)
@@ -1028,6 +1457,9 @@ case class ArrayPosition(left: Expression, right: Expression)
since = "2.4.0")
case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil {
+ @transient private lazy val ordering: Ordering[Any] =
+ TypeUtils.getInterpretedOrdering(left.dataType.asInstanceOf[MapType].keyType)
+
override def dataType: DataType = left.dataType match {
case ArrayType(elementType, _) => elementType
case MapType(_, valueType, _) => valueType
@@ -1038,10 +1470,21 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
left.dataType match {
case _: ArrayType => IntegerType
case _: MapType => left.dataType.asInstanceOf[MapType].keyType
+ case _ => AnyDataType // no match for a wrong 'left' expression type
}
)
}
+ override def checkInputDataTypes(): TypeCheckResult = {
+ super.checkInputDataTypes() match {
+ case f: TypeCheckResult.TypeCheckFailure => f
+ case TypeCheckResult.TypeCheckSuccess if left.dataType.isInstanceOf[MapType] =>
+ TypeUtils.checkForOrderingExpr(
+ left.dataType.asInstanceOf[MapType].keyType, s"function $prettyName")
+ case TypeCheckResult.TypeCheckSuccess => TypeCheckResult.TypeCheckSuccess
+ }
+ }
+
override def nullable: Boolean = true
override def nullSafeEval(value: Any, ordinal: Any): Any = {
@@ -1066,7 +1509,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
}
}
case _: MapType =>
- getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType)
+ getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType, ordering)
}
}
@@ -1212,7 +1655,7 @@ case class Concat(children: Seq[Expression]) extends Expression {
expressions = inputs,
funcName = "valueConcat",
extraArguments = (s"$javaType[]", args) :: Nil)
- ev.copy(s"""
+ ev.copy(code"""
$initCode
$codes
$javaType ${ev.value} = $concatenator.concat($args);
@@ -1468,3 +1911,152 @@ case class Flatten(child: Expression) extends UnaryExpression {
override def prettyName: String = "flatten"
}
+
+/**
+ * Returns an array containing the given element (left) repeated count (right) times.
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(element, count) - Returns the array containing element count times.",
+ examples = """
+ Examples:
+ > SELECT _FUNC_('123', 2);
+ ['123', '123']
+ """,
+ since = "2.4.0")
+case class ArrayRepeat(left: Expression, right: Expression)
+ extends BinaryExpression with ExpectsInputTypes {
+
+ private val MAX_ARRAY_LENGTH = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
+
+ override def dataType: ArrayType = ArrayType(left.dataType, left.nullable)
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegerType)
+
+ override def nullable: Boolean = right.nullable
+
+ override def eval(input: InternalRow): Any = {
+ val count = right.eval(input)
+ if (count == null) {
+ null
+ } else {
+ if (count.asInstanceOf[Int] > MAX_ARRAY_LENGTH) {
+ throw new RuntimeException(s"Unsuccessful try to create array with $count elements " +
+ s"due to exceeding the array size limit $MAX_ARRAY_LENGTH.");
+ }
+ val element = left.eval(input)
+ new GenericArrayData(Array.fill(count.asInstanceOf[Int])(element))
+ }
+ }
+
+ override def prettyName: String = "array_repeat"
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val leftGen = left.genCode(ctx)
+ val rightGen = right.genCode(ctx)
+ val element = leftGen.value
+ val count = rightGen.value
+ val et = dataType.elementType
+
+ val coreLogic = if (CodeGenerator.isPrimitiveType(et)) {
+ genCodeForPrimitiveElement(ctx, et, element, count, leftGen.isNull, ev.value)
+ } else {
+ genCodeForNonPrimitiveElement(ctx, element, count, leftGen.isNull, ev.value)
+ }
+ val resultCode = nullElementsProtection(ev, rightGen.isNull, coreLogic)
+
+ ev.copy(code =
+ code"""
+ |boolean ${ev.isNull} = false;
+ |${leftGen.code}
+ |${rightGen.code}
+ |${CodeGenerator.javaType(dataType)} ${ev.value} =
+ | ${CodeGenerator.defaultValue(dataType)};
+ |$resultCode
+ """.stripMargin)
+ }
+
+ private def nullElementsProtection(
+ ev: ExprCode,
+ rightIsNull: String,
+ coreLogic: String): String = {
+ if (nullable) {
+ s"""
+ |if ($rightIsNull) {
+ | ${ev.isNull} = true;
+ |} else {
+ | ${coreLogic}
+ |}
+ """.stripMargin
+ } else {
+ coreLogic
+ }
+ }
+
+ private def genCodeForNumberOfElements(ctx: CodegenContext, count: String): (String, String) = {
+ val numElements = ctx.freshName("numElements")
+ val numElementsCode =
+ s"""
+ |int $numElements = 0;
+ |if ($count > 0) {
+ | $numElements = $count;
+ |}
+ |if ($numElements > $MAX_ARRAY_LENGTH) {
+ | throw new RuntimeException("Unsuccessful try to create array with " + $numElements +
+ | " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.");
+ |}
+ """.stripMargin
+
+ (numElements, numElementsCode)
+ }
+
+ private def genCodeForPrimitiveElement(
+ ctx: CodegenContext,
+ elementType: DataType,
+ element: String,
+ count: String,
+ leftIsNull: String,
+ arrayDataName: String): String = {
+ val tempArrayDataName = ctx.freshName("tempArrayData")
+ val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
+ val errorMessage = s" $prettyName failed."
+ val (numElemName, numElemCode) = genCodeForNumberOfElements(ctx, count)
+
+ s"""
+ |$numElemCode
+ |${ctx.createUnsafeArray(tempArrayDataName, numElemName, elementType, errorMessage)}
+ |if (!$leftIsNull) {
+ | for (int k = 0; k < $tempArrayDataName.numElements(); k++) {
+ | $tempArrayDataName.set$primitiveValueTypeName(k, $element);
+ | }
+ |} else {
+ | for (int k = 0; k < $tempArrayDataName.numElements(); k++) {
+ | $tempArrayDataName.setNullAt(k);
+ | }
+ |}
+ |$arrayDataName = $tempArrayDataName;
+ """.stripMargin
+ }
+
+ private def genCodeForNonPrimitiveElement(
+ ctx: CodegenContext,
+ element: String,
+ count: String,
+ leftIsNull: String,
+ arrayDataName: String): String = {
+ val genericArrayClass = classOf[GenericArrayData].getName
+ val arrayName = ctx.freshName("arrayObject")
+ val (numElemName, numElemCode) = genCodeForNumberOfElements(ctx, count)
+
+ s"""
+ |$numElemCode
+ |Object[] $arrayName = new Object[(int)$numElemName];
+ |if (!$leftIsNull) {
+ | for (int k = 0; k < $numElemName; k++) {
+ | $arrayName[k] = $element;
+ | }
+ |}
+ |$arrayDataName = new $genericArrayClass($arrayName);
+ """.stripMargin
+ }
+
+}
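
Stepping back from collectionOperations.scala: the fast and brute-force strategies documented for ArraysOverlap above reduce to roughly the following standalone sketch of the fast path, with Option standing in for SQL NULL. Plain Scala, not Spark's ArrayData API:

```scala
import scala.collection.mutable

// Hash the smaller array, probe with the bigger one; return None (SQL NULL)
// when no match is found but a null element was seen and both arrays are non-empty.
def arraysOverlap[T](a1: Seq[Option[T]], a2: Seq[Option[T]]): Option[Boolean] = {
  var hasNull = false
  val (bigger, smaller) = if (a1.size > a2.size) (a1, a2) else (a2, a1)
  if (smaller.nonEmpty) {
    val seen = mutable.HashSet.empty[T]
    smaller.foreach {
      case Some(v) => seen += v
      case None    => hasNull = true
    }
    bigger.foreach {
      case Some(v) if seen.contains(v) => return Some(true)
      case None                        => hasNull = true
      case _                           =>
    }
  }
  if (hasNull) None else Some(false)
}

assert(arraysOverlap(Seq(Some(1), Some(2)), Seq(Some(2), Some(3))) == Some(true))
assert(arraysOverlap(Seq(Some(1), None), Seq(Some(3))) == None)
```
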
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 67876a8565488..a9867aaeb0cfe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
@@ -63,7 +64,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
val (preprocess, assigns, postprocess, arrayData) =
GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false)
ev.copy(
- code = preprocess + assigns + postprocess,
+ code = code"${preprocess}${assigns}${postprocess}",
value = JavaCode.variable(arrayData, dataType),
isNull = FalseLiteral)
}
@@ -219,7 +220,7 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
val (preprocessValueData, assignValues, postprocessValueData, valueArrayData) =
GenArrayData.genCodeToCreateArrayData(ctx, valueDt, evalValues, false)
val code =
- s"""
+ code"""
final boolean ${ev.isNull} = false;
$preprocessKeyData
$assignKeys
@@ -373,7 +374,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc
extraArguments = "Object[]" -> values :: Nil)
ev.copy(code =
- s"""
+ code"""
|Object[] $values = new Object[${valExprs.size}];
|$valuesCode
|final InternalRow ${ev.value} = new $rowClass($values);
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 3fba52d745453..99671d5b863c4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
-import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData}
+import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData, TypeUtils}
import org.apache.spark.sql.types._
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -273,7 +273,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes {
// todo: current search is O(n), improve it.
- def getValueEval(value: Any, ordinal: Any, keyType: DataType): Any = {
+ def getValueEval(value: Any, ordinal: Any, keyType: DataType, ordering: Ordering[Any]): Any = {
val map = value.asInstanceOf[MapData]
val length = map.numElements()
val keys = map.keyArray()
@@ -282,7 +282,7 @@ abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTy
var i = 0
var found = false
while (i < length && !found) {
- if (keys.get(i, keyType) == ordinal) {
+ if (ordering.equiv(keys.get(i, keyType), ordinal)) {
found = true
} else {
i += 1
@@ -345,8 +345,19 @@ abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTy
case class GetMapValue(child: Expression, key: Expression)
extends GetMapValueUtil with ExtractValue with NullIntolerant {
+ @transient private lazy val ordering: Ordering[Any] =
+ TypeUtils.getInterpretedOrdering(keyType)
+
private def keyType = child.dataType.asInstanceOf[MapType].keyType
+ override def checkInputDataTypes(): TypeCheckResult = {
+ super.checkInputDataTypes() match {
+ case f: TypeCheckResult.TypeCheckFailure => f
+ case TypeCheckResult.TypeCheckSuccess =>
+ TypeUtils.checkForOrderingExpr(keyType, s"function $prettyName")
+ }
+ }
+
// We have done type checking for child in `ExtractValue`, so only need to check the `key`.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType)
@@ -363,7 +374,7 @@ case class GetMapValue(child: Expression, key: Expression)
// todo: current search is O(n), improve it.
override def nullSafeEval(value: Any, ordinal: Any): Any = {
- getValueEval(value, ordinal, keyType)
+ getValueEval(value, ordinal, keyType, ordering)
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
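
The switch from == to ordering.equiv in the map key lookup above matters for key types whose runtime representation has no structural equality, binary keys for example. A plain-Scala illustration of the pitfall (not Spark code):

```scala
// BinaryType map keys are Array[Byte] at runtime; == on JVM arrays is
// reference equality, so a lookup key built elsewhere would never match.
val storedKey: Array[Byte] = Array(1, 2, 3)
val lookupKey: Array[Byte] = Array(1, 2, 3)

println(storedKey == lookupKey)                         // false: different instances
println(java.util.Arrays.equals(storedKey, lookupKey))  // true: element-wise comparison
```

TypeUtils.getInterpretedOrdering supplies an Ordering whose equiv compares such values element-wise, which is why the key search now goes through ordering.equiv instead of ==.
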
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index 205d77f6a9acf..77ac6c088022e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._
// scalastyle:off line.size.limit
@@ -66,7 +67,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
val falseEval = falseValue.genCode(ctx)
val code =
- s"""
+ code"""
|${condEval.code}
|boolean ${ev.isNull} = false;
|${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -265,7 +266,7 @@ case class CaseWhen(
}.mkString)
ev.copy(code =
- s"""
+ code"""
|${CodeGenerator.JAVA_BYTE} $resultState = $NOT_MATCHED;
|do {
| $codes
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
index 76aa61415a11f..e8d85f72f7a7a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
@@ -27,6 +27,7 @@ import org.apache.commons.lang3.StringEscapeUtils
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@@ -717,7 +718,7 @@ abstract class UnixTime
} else {
val formatterName = ctx.addReferenceObj("formatter", formatter, df)
val eval1 = left.genCode(ctx)
- ev.copy(code = s"""
+ ev.copy(code = code"""
${eval1.code}
boolean ${ev.isNull} = ${eval1.isNull};
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -746,7 +747,7 @@ abstract class UnixTime
})
case TimestampType =>
val eval1 = left.genCode(ctx)
- ev.copy(code = s"""
+ ev.copy(code = code"""
${eval1.code}
boolean ${ev.isNull} = ${eval1.isNull};
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -757,7 +758,7 @@ abstract class UnixTime
val tz = ctx.addReferenceObj("timeZone", timeZone)
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
val eval1 = left.genCode(ctx)
- ev.copy(code = s"""
+ ev.copy(code = code"""
${eval1.code}
boolean ${ev.isNull} = ${eval1.isNull};
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -852,7 +853,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[
} else {
val formatterName = ctx.addReferenceObj("formatter", formatter, df)
val t = left.genCode(ctx)
- ev.copy(code = s"""
+ ev.copy(code = code"""
${t.code}
boolean ${ev.isNull} = ${t.isNull};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -1042,7 +1043,7 @@ case class StringToTimestampWithoutTimezone(child: Expression, timeZoneId: Optio
val tz = ctx.addReferenceObj("timeZone", timeZone)
val longOpt = ctx.freshName("longOpt")
val eval = child.genCode(ctx)
- val code = s"""
+ val code = code"""
|${eval.code}
|${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = true;
|${CodeGenerator.JAVA_LONG} ${ev.value} = ${CodeGenerator.defaultValue(TimestampType)};
@@ -1090,7 +1091,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression)
if (right.foldable) {
val tz = right.eval().asInstanceOf[UTF8String]
if (tz == null) {
- ev.copy(code = s"""
+ ev.copy(code = code"""
|boolean ${ev.isNull} = true;
|long ${ev.value} = 0;
""".stripMargin)
@@ -1104,7 +1105,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression)
ctx.addImmutableStateIfNotExists(tzClass, utcTerm,
v => s"""$v = $dtu.getTimeZone("UTC");""")
val eval = left.genCode(ctx)
- ev.copy(code = s"""
+ ev.copy(code = code"""
|${eval.code}
|boolean ${ev.isNull} = ${eval.isNull};
|long ${ev.value} = 0;
@@ -1194,13 +1195,21 @@ case class AddMonths(startDate: Expression, numMonths: Expression)
}
/**
- * Returns number of months between dates date1 and date2.
+ * Returns the number of months between times `timestamp1` and `timestamp2`.
+ * If `timestamp1` is later than `timestamp2`, then the result is positive.
+ * If `timestamp1` and `timestamp2` are on the same day of month, or both
+ * are the last day of month, time of day will be ignored. Otherwise, the
+ * difference is calculated based on 31 days per month, and rounded to
+ * 8 digits unless roundOff=false.
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = """
- _FUNC_(timestamp1, timestamp2[, roundOff]) - Returns number of months between `timestamp1` and `timestamp2`.
- The result is rounded to 8 decimal places by default. Set roundOff=false otherwise."""",
+ _FUNC_(timestamp1, timestamp2[, roundOff]) - If `timestamp1` is later than `timestamp2`, then the result
+ is positive. If `timestamp1` and `timestamp2` are on the same day of month, or both
+ are the last day of month, time of day will be ignored. Otherwise, the difference is
+ calculated based on 31 days per month, and rounded to 8 digits unless roundOff=false.
+ """,
examples = """
Examples:
> SELECT _FUNC_('1997-02-28 10:30:00', '1996-10-30');
@@ -1279,7 +1288,7 @@ case class ToUTCTimestamp(left: Expression, right: Expression)
if (right.foldable) {
val tz = right.eval().asInstanceOf[UTF8String]
if (tz == null) {
- ev.copy(code = s"""
+ ev.copy(code = code"""
|boolean ${ev.isNull} = true;
|long ${ev.value} = 0;
""".stripMargin)
@@ -1293,7 +1302,7 @@ case class ToUTCTimestamp(left: Expression, right: Expression)
ctx.addImmutableStateIfNotExists(tzClass, utcTerm,
v => s"""$v = $dtu.getTimeZone("UTC");""")
val eval = left.genCode(ctx)
- ev.copy(code = s"""
+ ev.copy(code = code"""
|${eval.code}
|boolean ${ev.isNull} = ${eval.isNull};
|long ${ev.value} = 0;
@@ -1436,13 +1445,13 @@ trait TruncInstant extends BinaryExpression with ImplicitCastInputTypes {
val javaType = CodeGenerator.javaType(dataType)
if (format.foldable) {
if (truncLevel == DateTimeUtils.TRUNC_INVALID || truncLevel > maxLevel) {
- ev.copy(code = s"""
+ ev.copy(code = code"""
boolean ${ev.isNull} = true;
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""")
} else {
val t = instant.genCode(ctx)
val truncFuncStr = truncFunc(t.value, truncLevel.toString)
- ev.copy(code = s"""
+ ev.copy(code = code"""
${t.code}
boolean ${ev.isNull} = ${t.isNull};
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
index db1579ba28671..04de83343be71 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, EmptyBlock, ExprCode}
import org.apache.spark.sql.types._
/**
@@ -72,7 +72,8 @@ case class PromotePrecision(child: Expression) extends UnaryExpression {
override def eval(input: InternalRow): Any = child.eval(input)
/** Just a simple pass-through for code generation. */
override def genCode(ctx: CodegenContext): ExprCode = child.genCode(ctx)
- override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev.copy("")
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
+ ev.copy(EmptyBlock)
override def prettyName: String = "promote_precision"
override def sql: String = child.sql
override lazy val canonicalized: Expression = child.canonicalized
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index 3af4bfebad45e..b7c52f1d7b40a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types._
@@ -215,7 +216,7 @@ case class Stack(children: Seq[Expression]) extends Generator {
// Create the collection.
val wrapperClass = classOf[mutable.WrappedArray[_]].getName
ev.copy(code =
- s"""
+ code"""
|$code
|$wrapperClass ${ev.value} = $wrapperClass$$.MODULE$$.make($rowData);
""".stripMargin, isNull = FalseLiteral)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
index ef790338bdd27..cec00b66f873c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -28,6 +28,7 @@ import org.apache.commons.codec.digest.DigestUtils
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
@@ -293,7 +294,7 @@ abstract class HashExpression[E] extends Expression {
foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
ev.copy(code =
- s"""
+ code"""
|$hashResultType ${ev.value} = $seed;
|$codes
""".stripMargin)
@@ -674,7 +675,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
ev.copy(code =
- s"""
+ code"""
|${CodeGenerator.JAVA_INT} ${ev.value} = $seed;
|${CodeGenerator.JAVA_INT} $childHash = 0;
|$codes
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala
index 2a3cc580273ee..3b0141ad52cc7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.rdd.InputFileBlockHolder
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{DataType, LongType, StringType}
import org.apache.spark.unsafe.types.UTF8String
@@ -42,8 +43,9 @@ case class InputFileName() extends LeafExpression with Nondeterministic {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
- ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " +
- s"$className.getInputFilePath();", isNull = FalseLiteral)
+ val typeDef = s"final ${CodeGenerator.javaType(dataType)}"
+ ev.copy(code = code"$typeDef ${ev.value} = $className.getInputFilePath();",
+ isNull = FalseLiteral)
}
}
@@ -65,8 +67,8 @@ case class InputFileBlockStart() extends LeafExpression with Nondeterministic {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
- ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " +
- s"$className.getStartOffset();", isNull = FalseLiteral)
+ val typeDef = s"final ${CodeGenerator.javaType(dataType)}"
+ ev.copy(code = code"$typeDef ${ev.value} = $className.getStartOffset();", isNull = FalseLiteral)
}
}
@@ -88,7 +90,7 @@ case class InputFileBlockLength() extends LeafExpression with Nondeterministic {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
- ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " +
- s"$className.getLength();", isNull = FalseLiteral)
+ val typeDef = s"final ${CodeGenerator.javaType(dataType)}"
+ ev.copy(code = code"$typeDef ${ev.value} = $className.getLength();", isNull = FalseLiteral)
}
}
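
The recurring change across these expression files swaps the plain `s"""` interpolator for the `code"""` interpolator imported from `codegen.Block._`, which keeps generated Java snippets as composable block values instead of flat strings. Below is a minimal, hypothetical stand-in for that idea (the real catalyst `Block`/`ExprCode` types carry more structure than this sketch):

```scala
// Hypothetical, simplified stand-in for the catalyst `code"..."` interpolator: unlike
// `s"..."`, the interpolated pieces are kept as tokens, so blocks stay composable and
// are rendered to Java source text only when needed.
final case class CodeBlock(tokens: Seq[Any]) {
  // Mirrors the `obj.code + code"""..."""` composition pattern used in the patch.
  def +(other: CodeBlock): CodeBlock = CodeBlock(tokens ++ ("\n" +: other.tokens))
  override def toString: String = tokens.iterator.map(_.toString).mkString
}

object CodeBlock {
  implicit class BlockHelper(val sc: StringContext) extends AnyVal {
    def code(args: Any*): CodeBlock = {
      // Interleave the literal parts with the interpolated arguments, keeping both.
      val interleaved = sc.parts.zipAll(args, "", "").flatMap { case (p, a) => Seq(p, a) }
      CodeBlock(interleaved.filterNot(_ == ""))
    }
  }
}

object CodeBlockDemo extends App {
  import CodeBlock._
  val value = "value_0"
  val decl = code"int $value = 42;"
  val use  = code"System.out.println($value);"
  println(decl + use) // int value_0 = 42; \n System.out.println(value_0);
}
```
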
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
index 34161f0f03f4a..04a4eb0ffc032 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
@@ -548,7 +548,7 @@ case class JsonToStructs(
forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA))
override def checkInputDataTypes(): TypeCheckResult = nullableSchema match {
- case _: StructType | ArrayType(_: StructType, _) =>
+ case _: StructType | ArrayType(_: StructType, _) | _: MapType =>
super.checkInputDataTypes()
case _ => TypeCheckResult.TypeCheckFailure(
s"Input schema ${nullableSchema.simpleString} must be a struct or an array of structs.")
@@ -558,6 +558,7 @@ case class JsonToStructs(
lazy val rowSchema = nullableSchema match {
case st: StructType => st
case ArrayType(st: StructType, _) => st
+ case mt: MapType => mt
}
// This converts parsed rows to the desired output by the given schema.
@@ -567,6 +568,8 @@ case class JsonToStructs(
(rows: Seq[InternalRow]) => if (rows.length == 1) rows.head else null
case ArrayType(_: StructType, _) =>
(rows: Seq[InternalRow]) => new GenericArrayData(rows)
+ case _: MapType =>
+ (rows: Seq[InternalRow]) => rows.head.getMap(0)
}
@transient
@@ -613,6 +616,11 @@ case class JsonToStructs(
}
override def inputTypes: Seq[AbstractDataType] = StringType :: Nil
+
+ override def sql: String = schema match {
+ case _: MapType => "entries"
+ case _ => super.sql
+ }
}
/**
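
A short usage sketch of what this hunk enables: `from_json` accepting a `MapType` schema, so the JSON object keys do not have to be declared up front as `StructType` field names. The session setup and data are illustrative only; the `entries` alias mirrors the new `sql` override.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.from_json
import org.apache.spark.sql.types.{IntegerType, MapType, StringType}

object FromJsonMapSketch extends App {
  val spark = SparkSession.builder().master("local[*]").appName("from_json-map").getOrCreate()
  import spark.implicits._

  val df = Seq("""{"a": 1, "b": 2}""").toDF("json")

  // The schema handed to from_json may now be a MapType instead of a struct schema.
  val parsed = df.select(from_json($"json", MapType(StringType, IntegerType)).as("entries"))
  parsed.show(truncate = false) // entries = Map(a -> 1, b -> 2)

  spark.stop()
}
```
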
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index bc4cfcec47425..c2e1720259b53 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.NumberConverter
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -1191,11 +1192,11 @@ abstract class RoundBase(child: Expression, scale: Expression,
val javaType = CodeGenerator.javaType(dataType)
if (scaleV == null) { // if scale is null, no need to eval its child at all
- ev.copy(code = s"""
+ ev.copy(code = code"""
boolean ${ev.isNull} = true;
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""")
} else {
- ev.copy(code = s"""
+ ev.copy(code = code"""
${ce.code}
boolean ${ev.isNull} = ${ce.isNull};
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index b7834696cafc3..5d98dac46cf17 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -21,6 +21,7 @@ import java.util.UUID
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.RandomUUIDGenerator
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -88,7 +89,7 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa
// Use unnamed reference that doesn't create a local field here to reduce the number of fields
// because errMsgField is used only when the value is null or false.
val errMsgField = ctx.addReferenceObj("errMsg", errMsg)
- ExprCode(code = s"""${eval.code}
+ ExprCode(code = code"""${eval.code}
|if (${eval.isNull} || !${eval.value}) {
| throw new RuntimeException($errMsgField);
|}""".stripMargin, isNull = TrueLiteral,
@@ -151,7 +152,7 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Sta
ctx.addPartitionInitializationStatement(s"$randomGen = " +
"new org.apache.spark.sql.catalyst.util.RandomUUIDGenerator(" +
s"${randomSeed.get}L + partitionIndex);")
- ev.copy(code = s"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();",
+ ev.copy(code = code"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();",
isNull = FalseLiteral)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
index 0787342bce6bc..2eeed3bbb2d91 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
@@ -111,7 +112,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
ev.copy(code =
- s"""
+ code"""
|${ev.isNull} = true;
|$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|do {
@@ -232,7 +233,7 @@ case class IsNaN(child: Expression) extends UnaryExpression
val eval = child.genCode(ctx)
child.dataType match {
case DoubleType | FloatType =>
- ev.copy(code = s"""
+ ev.copy(code = code"""
${eval.code}
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value});""", isNull = FalseLiteral)
@@ -278,7 +279,7 @@ case class NaNvl(left: Expression, right: Expression)
val rightGen = right.genCode(ctx)
left.dataType match {
case DoubleType | FloatType =>
- ev.copy(code = s"""
+ ev.copy(code = code"""
${leftGen.code}
boolean ${ev.isNull} = false;
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -440,7 +441,7 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate
}.mkString)
ev.copy(code =
- s"""
+ code"""
|${CodeGenerator.JAVA_INT} $nonnull = 0;
|do {
| $codes
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index f974fd81fc788..2bf4203d0fec3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@@ -269,7 +270,7 @@ case class StaticInvoke(
s"${ev.value} = $callFunc;"
}
- val code = s"""
+ val code = code"""
$argCode
$prepareIsNull
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -385,8 +386,7 @@ case class Invoke(
"""
}
- val code = s"""
- ${obj.code}
+ val code = obj.code + code"""
boolean ${ev.isNull} = true;
$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${obj.isNull}) {
@@ -492,7 +492,7 @@ case class NewInstance(
s"new $className($argString)"
}
- val code = s"""
+ val code = code"""
$argCode
${outer.map(_.code).getOrElse("")}
final $javaType ${ev.value} = ${ev.isNull} ?
@@ -532,9 +532,7 @@ case class UnwrapOption(
val javaType = CodeGenerator.javaType(dataType)
val inputObject = child.genCode(ctx)
- val code = s"""
- ${inputObject.code}
-
+ val code = inputObject.code + code"""
final boolean ${ev.isNull} = ${inputObject.isNull} || ${inputObject.value}.isEmpty();
$javaType ${ev.value} = ${ev.isNull} ? ${CodeGenerator.defaultValue(dataType)} :
(${CodeGenerator.boxedType(javaType)}) ${inputObject.value}.get();
@@ -564,9 +562,7 @@ case class WrapOption(child: Expression, optType: DataType)
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val inputObject = child.genCode(ctx)
- val code = s"""
- ${inputObject.code}
-
+ val code = inputObject.code + code"""
scala.Option ${ev.value} =
${inputObject.isNull} ?
scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value});
@@ -935,8 +931,7 @@ case class MapObjects private(
)
}
- val code = s"""
- ${genInputData.code}
+ val code = genInputData.code + code"""
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${genInputData.isNull}) {
@@ -1147,8 +1142,7 @@ case class CatalystToExternalMap private(
"""
val getBuilderResult = s"${ev.value} = (${collClass.getName}) $builderValue.result();"
- val code = s"""
- ${genInputData.code}
+ val code = genInputData.code + code"""
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${genInputData.isNull}) {
@@ -1391,9 +1385,8 @@ case class ExternalMapToCatalyst private(
val mapCls = classOf[ArrayBasedMapData].getName
val convertedKeyType = CodeGenerator.boxedType(keyConverter.dataType)
val convertedValueType = CodeGenerator.boxedType(valueConverter.dataType)
- val code =
- s"""
- ${inputMap.code}
+ val code = inputMap.code +
+ code"""
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${inputMap.isNull}) {
final int $length = ${inputMap.value}.size();
@@ -1471,7 +1464,7 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType)
val schemaField = ctx.addReferenceObj("schema", schema)
val code =
- s"""
+ code"""
|Object[] $values = new Object[${children.size}];
|$childrenCode
|final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, $schemaField);
@@ -1499,8 +1492,7 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean)
val javaType = CodeGenerator.javaType(dataType)
val serialize = s"$serializer.serialize(${input.value}, null).array()"
- val code = s"""
- ${input.code}
+ val code = input.code + code"""
final $javaType ${ev.value} =
${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $serialize;
"""
@@ -1532,8 +1524,7 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B
val deserialize =
s"($javaType) $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null)"
- val code = s"""
- ${input.code}
+ val code = input.code + code"""
final $javaType ${ev.value} =
${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $deserialize;
"""
@@ -1614,9 +1605,8 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
funcName = "initializeJavaBean",
extraArguments = beanInstanceJavaType -> javaBeanInstance :: Nil)
- val code =
- s"""
- |${instanceGen.code}
+ val code = instanceGen.code +
+ code"""
|$beanInstanceJavaType $javaBeanInstance = ${instanceGen.value};
|if (!${instanceGen.isNull}) {
| $initializeCode
@@ -1664,9 +1654,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil)
// because errMsgField is used only when the value is null.
val errMsgField = ctx.addReferenceObj("errMsg", errMsg)
- val code = s"""
- ${childGen.code}
-
+ val code = childGen.code + code"""
if (${childGen.isNull}) {
throw new NullPointerException($errMsgField);
}
@@ -1709,7 +1697,7 @@ case class GetExternalRowField(
// because errMsgField is used only when the field is null.
val errMsgField = ctx.addReferenceObj("errMsg", errMsg)
val row = child.genCode(ctx)
- val code = s"""
+ val code = code"""
${row.code}
if (${row.isNull}) {
@@ -1784,7 +1772,7 @@ case class ValidateExternalType(child: Expression, expected: DataType)
s"$obj instanceof ${CodeGenerator.boxedType(dataType)}"
}
- val code = s"""
+ val code = code"""
${input.code}
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${input.isNull}) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index f8c6dc4e6adc9..f54103c4fbfba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -22,6 +22,7 @@ import scala.collection.immutable.TreeSet
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
@@ -290,7 +291,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
}.mkString("\n"))
ev.copy(code =
- s"""
+ code"""
|${valueGen.code}
|byte $tmpResult = $HAS_NULL;
|if (!${valueGen.isNull}) {
@@ -354,7 +355,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
""
}
ev.copy(code =
- s"""
+ code"""
|${childGen.code}
|${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = ${childGen.isNull};
|${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = false;
@@ -406,7 +407,7 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with
// The result should be `false`, if any of them is `false` whenever the other is null or not.
if (!left.nullable && !right.nullable) {
- ev.copy(code = s"""
+ ev.copy(code = code"""
${eval1.code}
boolean ${ev.value} = false;
@@ -415,7 +416,7 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with
${ev.value} = ${eval2.value};
}""", isNull = FalseLiteral)
} else {
- ev.copy(code = s"""
+ ev.copy(code = code"""
${eval1.code}
boolean ${ev.isNull} = false;
boolean ${ev.value} = false;
@@ -470,7 +471,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
// The result should be `true`, if any of them is `true` whenever the other is null or not.
if (!left.nullable && !right.nullable) {
ev.isNull = FalseLiteral
- ev.copy(code = s"""
+ ev.copy(code = code"""
${eval1.code}
boolean ${ev.value} = true;
@@ -479,7 +480,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
${ev.value} = ${eval2.value};
}""", isNull = FalseLiteral)
} else {
- ev.copy(code = s"""
+ ev.copy(code = code"""
${eval1.code}
boolean ${ev.isNull} = false;
boolean ${ev.value} = true;
@@ -621,7 +622,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
val eval1 = left.genCode(ctx)
val eval2 = right.genCode(ctx)
val equalCode = ctx.genEqual(left.dataType, eval1.value, eval2.value)
- ev.copy(code = eval1.code + eval2.code + s"""
+ ev.copy(code = eval1.code + eval2.code + code"""
boolean ${ev.value} = (${eval1.isNull} && ${eval2.isNull}) ||
(!${eval1.isNull} && !${eval2.isNull} && $equalCode);""", isNull = FalseLiteral)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
index 2653b28f6c3bd..926c2f00d430d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom
@@ -82,7 +83,7 @@ case class Rand(child: Expression) extends RDG {
val rngTerm = ctx.addMutableState(className, "rng")
ctx.addPartitionInitializationStatement(
s"$rngTerm = new $className(${seed}L + partitionIndex);")
- ev.copy(code = s"""
+ ev.copy(code = code"""
final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""",
isNull = FalseLiteral)
}
@@ -120,7 +121,7 @@ case class Randn(child: Expression) extends RDG {
val rngTerm = ctx.addMutableState(className, "rng")
ctx.addPartitionInitializationStatement(
s"$rngTerm = new $className(${seed}L + partitionIndex);")
- ev.copy(code = s"""
+ ev.copy(code = code"""
final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""",
isNull = FalseLiteral)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
index ad0c0791d895f..7b68bb771faf3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
@@ -23,6 +23,7 @@ import java.util.regex.{MatchResult, Pattern}
import org.apache.commons.lang3.StringEscapeUtils
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -123,7 +124,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi
// We don't use nullSafeCodeGen here because we don't want to re-evaluate right again.
val eval = left.genCode(ctx)
- ev.copy(code = s"""
+ ev.copy(code = code"""
${eval.code}
boolean ${ev.isNull} = ${eval.isNull};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -132,7 +133,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi
}
""")
} else {
- ev.copy(code = s"""
+ ev.copy(code = code"""
boolean ${ev.isNull} = true;
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
""")
@@ -198,7 +199,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress
// We don't use nullSafeCodeGen here because we don't want to re-evaluate right again.
val eval = left.genCode(ctx)
- ev.copy(code = s"""
+ ev.copy(code = code"""
${eval.code}
boolean ${ev.isNull} = ${eval.isNull};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
@@ -207,7 +208,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress
}
""")
} else {
- ev.copy(code = s"""
+ ev.copy(code = code"""
boolean ${ev.isNull} = true;
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
""")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index ea005a26a4c8b..9823b2fc5ad97 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -27,6 +27,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
@@ -105,7 +106,7 @@ case class ConcatWs(children: Seq[Expression])
expressions = inputs,
funcName = "valueConcatWs",
extraArguments = ("UTF8String[]", args) :: Nil)
- ev.copy(s"""
+ ev.copy(code"""
UTF8String[] $args = new UTF8String[$numArgs];
${separator.code}
$codes
@@ -149,7 +150,7 @@ case class ConcatWs(children: Seq[Expression])
}
}.unzip
- val codes = ctx.splitExpressionsWithCurrentInputs(evals.map(_.code))
+ val codes = ctx.splitExpressionsWithCurrentInputs(evals.map(_.code.toString))
val varargCounts = ctx.splitExpressionsWithCurrentInputs(
expressions = varargCount,
@@ -176,7 +177,7 @@ case class ConcatWs(children: Seq[Expression])
foldFunctions = _.map(funcCall => s"$idxVararg = $funcCall;").mkString("\n"))
ev.copy(
- s"""
+ code"""
$codes
int $varargNum = ${children.count(_.dataType == StringType) - 1};
int $idxVararg = 0;
@@ -288,7 +289,7 @@ case class Elt(children: Seq[Expression]) extends Expression {
}.mkString)
ev.copy(
- s"""
+ code"""
|${index.code}
|final int $indexVal = ${index.value};
|${CodeGenerator.JAVA_BOOLEAN} $indexMatched = false;
@@ -654,7 +655,7 @@ case class StringTrim(
val srcString = evals(0)
if (evals.length == 1) {
- ev.copy(evals.map(_.code).mkString + s"""
+ ev.copy(evals.map(_.code) :+ code"""
boolean ${ev.isNull} = false;
UTF8String ${ev.value} = null;
if (${srcString.isNull}) {
@@ -671,7 +672,7 @@ case class StringTrim(
} else {
${ev.value} = ${srcString.value}.trim(${trimString.value});
}"""
- ev.copy(evals.map(_.code).mkString + s"""
+ ev.copy(evals.map(_.code) :+ code"""
boolean ${ev.isNull} = false;
UTF8String ${ev.value} = null;
if (${srcString.isNull}) {
@@ -754,7 +755,7 @@ case class StringTrimLeft(
val srcString = evals(0)
if (evals.length == 1) {
- ev.copy(evals.map(_.code).mkString + s"""
+ ev.copy(evals.map(_.code) :+ code"""
boolean ${ev.isNull} = false;
UTF8String ${ev.value} = null;
if (${srcString.isNull}) {
@@ -771,7 +772,7 @@ case class StringTrimLeft(
} else {
${ev.value} = ${srcString.value}.trimLeft(${trimString.value});
}"""
- ev.copy(evals.map(_.code).mkString + s"""
+ ev.copy(evals.map(_.code) :+ code"""
boolean ${ev.isNull} = false;
UTF8String ${ev.value} = null;
if (${srcString.isNull}) {
@@ -856,7 +857,7 @@ case class StringTrimRight(
val srcString = evals(0)
if (evals.length == 1) {
- ev.copy(evals.map(_.code).mkString + s"""
+ ev.copy(evals.map(_.code) :+ code"""
boolean ${ev.isNull} = false;
UTF8String ${ev.value} = null;
if (${srcString.isNull}) {
@@ -873,7 +874,7 @@ case class StringTrimRight(
} else {
${ev.value} = ${srcString.value}.trimRight(${trimString.value});
}"""
- ev.copy(evals.map(_.code).mkString + s"""
+ ev.copy(evals.map(_.code) :+ code"""
boolean ${ev.isNull} = false;
UTF8String ${ev.value} = null;
if (${srcString.isNull}) {
@@ -1024,7 +1025,7 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression)
val substrGen = substr.genCode(ctx)
val strGen = str.genCode(ctx)
val startGen = start.genCode(ctx)
- ev.copy(code = s"""
+ ev.copy(code = code"""
int ${ev.value} = 0;
boolean ${ev.isNull} = false;
${startGen.code}
@@ -1350,7 +1351,7 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC
val formatter = classOf[java.util.Formatter].getName
val sb = ctx.freshName("sb")
val stringBuffer = classOf[StringBuffer].getName
- ev.copy(code = s"""
+ ev.copy(code = code"""
${pattern.code}
boolean ${ev.isNull} = ${pattern.isNull};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
index a5a4a13eb608b..c3a4ca8f64bf6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
@@ -36,7 +36,7 @@ import org.apache.spark.util.Utils
* Constructs a parser for a given schema that translates a json string to an [[InternalRow]].
*/
class JacksonParser(
- schema: StructType,
+ schema: DataType,
val options: JSONOptions) extends Logging {
import JacksonUtils._
@@ -57,7 +57,14 @@ class JacksonParser(
* to a value according to a desired schema. This is a wrapper for the method
* `makeConverter()` to handle a row wrapped with an array.
*/
- private def makeRootConverter(st: StructType): JsonParser => Seq[InternalRow] = {
+ private def makeRootConverter(dt: DataType): JsonParser => Seq[InternalRow] = {
+ dt match {
+ case st: StructType => makeStructRootConverter(st)
+ case mt: MapType => makeMapRootConverter(mt)
+ }
+ }
+
+ private def makeStructRootConverter(st: StructType): JsonParser => Seq[InternalRow] = {
val elementConverter = makeConverter(st)
val fieldConverters = st.map(_.dataType).map(makeConverter).toArray
(parser: JsonParser) => parseJsonToken[Seq[InternalRow]](parser, st) {
@@ -87,6 +94,13 @@ class JacksonParser(
}
}
+ private def makeMapRootConverter(mt: MapType): JsonParser => Seq[InternalRow] = {
+ val fieldConverter = makeConverter(mt.valueType)
+ (parser: JsonParser) => parseJsonToken[Seq[InternalRow]](parser, mt) {
+ case START_OBJECT => Seq(InternalRow(convertMap(parser, fieldConverter)))
+ }
+ }
+
/**
* Create a converter which converts the JSON documents held by the `JsonParser`
* to a value according to a desired schema.
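
For intuition, here is a stripped-down, stand-alone sketch of what a map root converter does with Jackson's streaming API: walk a `START_OBJECT` token stream and collect field-name/value pairs. Integer values are assumed for brevity; the real converter plugs `makeConverter`/`convertMap` into this walk and wraps the result in a single-column `InternalRow`.

```scala
import com.fasterxml.jackson.core.{JsonFactory, JsonToken}
import scala.collection.mutable

object MapRootConverterSketch extends App {
  val parser = new JsonFactory().createParser("""{"a": 1, "b": 2}""")

  val result = mutable.LinkedHashMap.empty[String, Int]
  assert(parser.nextToken() == JsonToken.START_OBJECT)
  while (parser.nextToken() != JsonToken.END_OBJECT) {
    val key = parser.getCurrentName // current token is FIELD_NAME
    parser.nextToken()              // advance to the value token
    result += key -> parser.getIntValue
  }
  println(result) // a -> 1, b -> 2
}
```
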
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 64eed23884584..b9ece295c2510 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -504,6 +504,9 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
withJoinRelations(join, relation)
}
if (ctx.pivotClause() != null) {
+ if (!ctx.lateralView.isEmpty) {
+ throw new ParseException("LATERAL cannot be used together with PIVOT in FROM clause", ctx)
+ }
withPivot(ctx.pivotClause, from)
} else {
ctx.lateralView.asScala.foldLeft(from)(withGenerate)
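
A hedged illustration of the new parser check: a FROM clause that combines LATERAL VIEW with PIVOT is now rejected with a `ParseException` before analysis. Table and column names below are made up; only the clause shape matters.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.parser.ParseException

object LateralPivotCheckSketch extends App {
  val spark = SparkSession.builder().master("local[*]").appName("lateral-pivot").getOrCreate()

  // Hypothetical query: LATERAL VIEW and PIVOT in the same FROM clause.
  val badSql =
    """SELECT * FROM sales
      |  LATERAL VIEW explode(items) t AS item
      |  PIVOT (sum(amount) FOR quarter IN ('Q1', 'Q2'))""".stripMargin

  try {
    spark.sql(badSql)
  } catch {
    case e: ParseException =>
      println(e.getMessage) // expected to mention that LATERAL cannot be used with PIVOT
  }

  spark.stop()
}
```
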
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
index e646da0659e85..80f15053005ff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
@@ -885,13 +885,13 @@ object DateTimeUtils {
/**
* Returns number of months between time1 and time2. time1 and time2 are expressed in
- * microseconds since 1.1.1970.
+ * microseconds since 1.1.1970. If time1 is later than time2, the result is positive.
*
- * If time1 and time2 having the same day of month, or both are the last day of month,
- * it returns an integer (time under a day will be ignored).
+ * If time1 and time2 are on the same day of month, or both are the last day of month,
+ * the time of day is ignored.
*
* Otherwise, the difference is calculated based on 31 days per month.
- * If `roundOff` is set to true, the result is rounded to 8 decimal places.
+ * The result is rounded to 8 decimal places if `roundOff` is set to true.
*/
def monthsBetween(
time1: SQLTimestamp,
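
A small usage sketch matching the clarified `monthsBetween` contract, using the public `months_between` function; the session setup is illustrative and the approximate value comes from the function's documented example.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{months_between, to_timestamp}

object MonthsBetweenSketch extends App {
  val spark = SparkSession.builder().master("local[*]").appName("months-between").getOrCreate()
  import spark.implicits._

  val df = Seq(("1997-02-28 10:30:00", "1996-10-30")).toDF("t1", "t2")

  // t1 is later than t2, so the result is positive; the two dates are not on the same
  // day of month, so the difference is based on 31-day months and (by default) rounded
  // to 8 decimal places -- roughly 3.94959677 for this documented example.
  df.select(months_between(to_timestamp($"t1"), to_timestamp($"t2"))).show()

  spark.stop()
}
```
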
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala
new file mode 100644
index 0000000000000..19f67236c8979
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.internal
+
+import java.util.{Map => JMap}
+
+import org.apache.spark.{TaskContext, TaskContextImpl}
+import org.apache.spark.internal.config.{ConfigEntry, ConfigProvider, ConfigReader}
+
+/**
+ * A read-only SQLConf that is created by tasks running on the executor side. It reads the
+ * configs from the local properties which are propagated from driver to executors.
+ */
+class ReadOnlySQLConf(context: TaskContext) extends SQLConf {
+
+ @transient override val settings: JMap[String, String] = {
+ context.asInstanceOf[TaskContextImpl].getLocalProperties().asInstanceOf[JMap[String, String]]
+ }
+
+ @transient override protected val reader: ConfigReader = {
+ new ConfigReader(new TaskContextConfigProvider(context))
+ }
+
+ override protected def setConfWithCheck(key: String, value: String): Unit = {
+ throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.")
+ }
+
+ override def unsetConf(key: String): Unit = {
+ throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.")
+ }
+
+ override def unsetConf(entry: ConfigEntry[_]): Unit = {
+ throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.")
+ }
+
+ override def clear(): Unit = {
+ throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.")
+ }
+
+ override def clone(): SQLConf = {
+ throw new UnsupportedOperationException("Cannot clone/copy ReadOnlySQLConf.")
+ }
+
+ override def copy(entries: (ConfigEntry[_], Any)*): SQLConf = {
+ throw new UnsupportedOperationException("Cannot clone/copy ReadOnlySQLConf.")
+ }
+}
+
+class TaskContextConfigProvider(context: TaskContext) extends ConfigProvider {
+ override def get(key: String): Option[String] = Option(context.getLocalProperty(key))
+}
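
A hedged sketch of what this new class enables: calling `SQLConf.get` inside a task now yields a `ReadOnlySQLConf` backed by the propagated job-local properties instead of throwing. Which values actually reach the executor depends on what the driver propagated, so unknown keys fall back to the supplied default here.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf

object ExecutorSideConfSketch extends App {
  val spark = SparkSession.builder().master("local[*]").appName("executor-side-conf").getOrCreate()

  val values = spark.sparkContext.parallelize(1 to 2, 2).map { _ =>
    // Inside a task, SQLConf.get returns a ReadOnlySQLConf reading TaskContext
    // local properties; keys that were not propagated fall back to the default.
    SQLConf.get.getConfString("spark.sql.caseSensitive", "<not propagated>")
  }.collect()

  println(values.toSeq)
  spark.stop()
}
```
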
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 0b1965c438e27..d0478d6ad250b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -27,7 +27,7 @@ import scala.util.matching.Regex
import org.apache.hadoop.fs.Path
-import org.apache.spark.TaskContext
+import org.apache.spark.{SparkContext, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
import org.apache.spark.network.util.ByteUnit
@@ -95,7 +95,9 @@ object SQLConf {
/**
* Returns the active config object within the current scope. If there is an active SparkSession,
- * the proper SQLConf associated with the thread's session is used.
+ * the proper SQLConf associated with the thread's active session is used. If it's called from
+ * tasks on the executor side, a SQLConf will be created from the job-local properties, which are set
+ * and propagated from the driver side.
*
* The way this works is a little bit convoluted, due to the fact that config was added initially
* only for physical plans (and as a result not in sql/catalyst module).
@@ -108,11 +110,20 @@ object SQLConf {
* run unit tests (that does not involve SparkSession) in serial order.
*/
def get: SQLConf = {
- if (Utils.isTesting && TaskContext.get != null) {
- // we're accessing it during task execution, fail.
- throw new IllegalStateException("SQLConf should only be created and accessed on the driver.")
+ if (TaskContext.get != null) {
+ new ReadOnlySQLConf(TaskContext.get())
+ } else {
+ if (Utils.isTesting && SparkContext.getActive.isDefined) {
+ // The DAGScheduler event loop thread does not have an active SparkSession, so the `confGetter`
+ // will return `fallbackConf` which is unexpected. Here we prevent it from happening.
+ val schedulerEventLoopThread =
+ SparkContext.getActive.get.dagScheduler.eventProcessLoop.eventThread
+ if (schedulerEventLoopThread.getId == Thread.currentThread().getId) {
+ throw new RuntimeException("Cannot get SQLConf inside scheduler event loop thread.")
+ }
+ }
+ confGetter.get()()
}
- confGetter.get()()
}
val OPTIMIZER_MAX_ITERATIONS = buildConf("spark.sql.optimizer.maxIterations")
@@ -1161,8 +1172,17 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+ val SQL_OPTIONS_REDACTION_PATTERN =
+ buildConf("spark.sql.redaction.options.regex")
+ .doc("Regex to decide which keys in a Spark SQL command's options map contain sensitive " +
+ "information. The values of options whose names that match this regex will be redacted " +
+ "in the explain output. This redaction is applied on top of the global redaction " +
+ s"configuration defined by ${SECRET_REDACTION_PATTERN.key}.")
+ .regexConf
+ .createWithDefault("(?i)url".r)
+
val SQL_STRING_REDACTION_PATTERN =
- ConfigBuilder("spark.sql.redaction.string.regex")
+ buildConf("spark.sql.redaction.string.regex")
.doc("Regex to decide which parts of strings produced by Spark contain sensitive " +
"information. When this regex matches a string part, that string part is replaced by a " +
"dummy value. This is currently used to redact the output of SQL explain commands. " +
@@ -1259,6 +1279,15 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+ val TOP_K_SORT_FALLBACK_THRESHOLD =
+ buildConf("spark.sql.execution.topKSortFallbackThreshold")
+ .internal()
+ .doc("In SQL queries with a SORT followed by a LIMIT like " +
+ "'SELECT x FROM t ORDER BY y LIMIT m', if m is under this threshold, do a top-K sort" +
+ " in memory, otherwise do a global sort which spills to disk if necessary.")
+ .intConf
+ .createWithDefault(Int.MaxValue)
+
object Deprecated {
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
}
@@ -1266,6 +1295,13 @@ object SQLConf {
object Replaced {
val MAPREDUCE_JOB_REDUCES = "mapreduce.job.reduces"
}
+
+ val CSV_PARSER_COLUMN_PRUNING = buildConf("spark.sql.csv.parser.columnPruning.enabled")
+ .internal()
+ .doc("If it is set to true, column names of the requested schema are passed to CSV parser. " +
+ "Other column values can be ignored during parsing even if they are malformed.")
+ .booleanConf
+ .createWithDefault(true)
}
/**
@@ -1284,7 +1320,7 @@ class SQLConf extends Serializable with Logging {
@transient protected[spark] val settings = java.util.Collections.synchronizedMap(
new java.util.HashMap[String, String]())
- @transient private val reader = new ConfigReader(settings)
+ @transient protected val reader = new ConfigReader(settings)
/** ************************ Spark SQL Params/Hints ******************* */
@@ -1420,10 +1456,12 @@ class SQLConf extends Serializable with Logging {
def fileCompressionFactor: Double = getConf(FILE_COMRESSION_FACTOR)
- def stringRedationPattern: Option[Regex] = SQL_STRING_REDACTION_PATTERN.readFrom(reader)
+ def stringRedactionPattern: Option[Regex] = getConf(SQL_STRING_REDACTION_PATTERN)
def sortBeforeRepartition: Boolean = getConf(SORT_BEFORE_REPARTITION)
+ def topKSortFallbackThreshold: Int = getConf(TOP_K_SORT_FALLBACK_THRESHOLD)
+
/**
* Returns the [[Resolver]] for the current configuration, which can be used to determine if two
* identifiers are equal.
@@ -1727,6 +1765,17 @@ class SQLConf extends Serializable with Logging {
}.toSeq
}
+ /**
+ * Redacts the given option map according to the description of SQL_OPTIONS_REDACTION_PATTERN.
+ */
+ def redactOptions(options: Map[String, String]): Map[String, String] = {
+ val regexes = Seq(
+ getConf(SQL_OPTIONS_REDACTION_PATTERN),
+ SECRET_REDACTION_PATTERN.readFrom(reader))
+
+ regexes.foldLeft(options.toSeq) { case (opts, r) => Utils.redact(Some(r), opts) }.toMap
+ }
+
/**
* Return whether a given key is set in this [[SQLConf]].
*/
@@ -1734,7 +1783,7 @@ class SQLConf extends Serializable with Logging {
settings.containsKey(key)
}
- private def setConfWithCheck(key: String, value: String): Unit = {
+ protected def setConfWithCheck(key: String, value: String): Unit = {
settings.put(key, value)
}
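
A small sketch of the new `redactOptions` helper together with `spark.sql.redaction.options.regex` (default `(?i)url`). The exact replacement text comes from `Utils.redact`, so the expected output is described only loosely; the option values are made up.

```scala
import org.apache.spark.sql.internal.SQLConf

object RedactOptionsSketch extends App {
  val conf = new SQLConf

  // Keys matching spark.sql.redaction.options.regex or the global secret-redaction
  // regex get their values replaced; other options pass through untouched.
  val options = Map(
    "url"     -> "jdbc:postgresql://db.example.com/app?password=hunter2",
    "dbtable" -> "public.users")

  println(conf.redactOptions(options)) // the "url" value is replaced by a redaction marker
}
```
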
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index 4ee12db9c10ca..0bef11659fc9e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -81,7 +81,11 @@ abstract class DataType extends AbstractDataType {
* (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`).
*/
private[spark] def sameType(other: DataType): Boolean =
- DataType.equalsIgnoreNullability(this, other)
+ if (SQLConf.get.caseSensitiveAnalysis) {
+ DataType.equalsIgnoreNullability(this, other)
+ } else {
+ DataType.equalsIgnoreCaseAndNullability(this, other)
+ }
/**
* Returns the same data type but set all nullability fields are true
@@ -214,7 +218,7 @@ object DataType {
/**
* Compares two types, ignoring nullability of ArrayType, MapType, StructType.
*/
- private[sql] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = {
+ private[types] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = {
(left, right) match {
case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) =>
equalsIgnoreNullability(leftElementType, rightElementType)
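
For illustration, a self-contained re-implementation (not the Spark-internal method) of the case-insensitive comparison that `sameType` now delegates to when `spark.sql.caseSensitive` is false: field names differing only in case compare as the same type, and nullability is still ignored.

```scala
import org.apache.spark.sql.types._

object SameTypeSketch extends App {
  def equalsIgnoreCaseAndNullability(left: DataType, right: DataType): Boolean =
    (left, right) match {
      case (ArrayType(l, _), ArrayType(r, _)) => equalsIgnoreCaseAndNullability(l, r)
      case (MapType(lk, lv, _), MapType(rk, rv, _)) =>
        equalsIgnoreCaseAndNullability(lk, rk) && equalsIgnoreCaseAndNullability(lv, rv)
      case (StructType(l), StructType(r)) =>
        l.length == r.length && l.zip(r).forall { case (lf, rf) =>
          lf.name.equalsIgnoreCase(rf.name) &&
            equalsIgnoreCaseAndNullability(lf.dataType, rf.dataType)
        }
      case (l, r) => l == r
    }

  val a = new StructType().add("ID", LongType, nullable = false)
  val b = new StructType().add("id", LongType, nullable = true)
  println(equalsIgnoreCaseAndNullability(a, b)) // true under case-insensitive analysis
}
```
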
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index f73e045685ee1..0acd3b490447d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -128,17 +128,17 @@ class TypeCoercionSuite extends AnalysisTest {
}
private def checkWidenType(
- widenFunc: (DataType, DataType, Boolean) => Option[DataType],
+ widenFunc: (DataType, DataType) => Option[DataType],
t1: DataType,
t2: DataType,
expected: Option[DataType],
isSymmetric: Boolean = true): Unit = {
- var found = widenFunc(t1, t2, conf.caseSensitiveAnalysis)
+ var found = widenFunc(t1, t2)
assert(found == expected,
s"Expected $expected as wider common type for $t1 and $t2, found $found")
// Test both directions to make sure the widening is symmetric.
if (isSymmetric) {
- found = widenFunc(t2, t1, conf.caseSensitiveAnalysis)
+ found = widenFunc(t2, t1)
assert(found == expected,
s"Expected $expected as wider common type for $t2 and $t1, found $found")
}
@@ -524,11 +524,11 @@ class TypeCoercionSuite extends AnalysisTest {
test("cast NullType for expressions that implement ExpectsInputTypes") {
import TypeCoercionSuite._
- ruleTest(TypeCoercion.ImplicitTypeCasts(conf),
+ ruleTest(new TypeCoercion.ImplicitTypeCasts(conf),
AnyTypeUnaryExpression(Literal.create(null, NullType)),
AnyTypeUnaryExpression(Literal.create(null, NullType)))
- ruleTest(TypeCoercion.ImplicitTypeCasts(conf),
+ ruleTest(new TypeCoercion.ImplicitTypeCasts(conf),
NumericTypeUnaryExpression(Literal.create(null, NullType)),
NumericTypeUnaryExpression(Literal.create(null, DoubleType)))
}
@@ -536,17 +536,17 @@ class TypeCoercionSuite extends AnalysisTest {
test("cast NullType for binary operators") {
import TypeCoercionSuite._
- ruleTest(TypeCoercion.ImplicitTypeCasts(conf),
+ ruleTest(new TypeCoercion.ImplicitTypeCasts(conf),
AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)),
AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)))
- ruleTest(TypeCoercion.ImplicitTypeCasts(conf),
+ ruleTest(new TypeCoercion.ImplicitTypeCasts(conf),
NumericTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)),
NumericTypeBinaryOperator(Literal.create(null, DoubleType), Literal.create(null, DoubleType)))
}
test("coalesce casts") {
- val rule = TypeCoercion.FunctionArgumentConversion(conf)
+ val rule = TypeCoercion.FunctionArgumentConversion
val intLit = Literal(1)
val longLit = Literal.create(1L)
@@ -606,7 +606,7 @@ class TypeCoercionSuite extends AnalysisTest {
}
test("CreateArray casts") {
- ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
+ ruleTest(TypeCoercion.FunctionArgumentConversion,
CreateArray(Literal(1.0)
:: Literal(1)
:: Literal.create(1.0, FloatType)
@@ -616,7 +616,7 @@ class TypeCoercionSuite extends AnalysisTest {
:: Cast(Literal.create(1.0, FloatType), DoubleType)
:: Nil))
- ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
+ ruleTest(TypeCoercion.FunctionArgumentConversion,
CreateArray(Literal(1.0)
:: Literal(1)
:: Literal("a")
@@ -626,7 +626,7 @@ class TypeCoercionSuite extends AnalysisTest {
:: Cast(Literal("a"), StringType)
:: Nil))
- ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
+ ruleTest(TypeCoercion.FunctionArgumentConversion,
CreateArray(Literal.create(null, DecimalType(5, 3))
:: Literal(1)
:: Nil),
@@ -634,7 +634,7 @@ class TypeCoercionSuite extends AnalysisTest {
:: Literal(1).cast(DecimalType(13, 3))
:: Nil))
- ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
+ ruleTest(TypeCoercion.FunctionArgumentConversion,
CreateArray(Literal.create(null, DecimalType(5, 3))
:: Literal.create(null, DecimalType(22, 10))
:: Literal.create(null, DecimalType(38, 38))
@@ -647,7 +647,7 @@ class TypeCoercionSuite extends AnalysisTest {
test("CreateMap casts") {
// type coercion for map keys
- ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
+ ruleTest(TypeCoercion.FunctionArgumentConversion,
CreateMap(Literal(1)
:: Literal("a")
:: Literal.create(2.0, FloatType)
@@ -658,7 +658,7 @@ class TypeCoercionSuite extends AnalysisTest {
:: Cast(Literal.create(2.0, FloatType), FloatType)
:: Literal("b")
:: Nil))
- ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
+ ruleTest(TypeCoercion.FunctionArgumentConversion,
CreateMap(Literal.create(null, DecimalType(5, 3))
:: Literal("a")
:: Literal.create(2.0, FloatType)
@@ -670,7 +670,7 @@ class TypeCoercionSuite extends AnalysisTest {
:: Literal("b")
:: Nil))
// type coercion for map values
- ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
+ ruleTest(TypeCoercion.FunctionArgumentConversion,
CreateMap(Literal(1)
:: Literal("a")
:: Literal(2)
@@ -681,7 +681,7 @@ class TypeCoercionSuite extends AnalysisTest {
:: Literal(2)
:: Cast(Literal(3.0), StringType)
:: Nil))
- ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
+ ruleTest(TypeCoercion.FunctionArgumentConversion,
CreateMap(Literal(1)
:: Literal.create(null, DecimalType(38, 0))
:: Literal(2)
@@ -693,7 +693,7 @@ class TypeCoercionSuite extends AnalysisTest {
:: Literal.create(null, DecimalType(38, 38)).cast(DecimalType(38, 38))
:: Nil))
// type coercion for both map keys and values
- ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
+ ruleTest(TypeCoercion.FunctionArgumentConversion,
CreateMap(Literal(1)
:: Literal("a")
:: Literal(2.0)
@@ -708,7 +708,7 @@ class TypeCoercionSuite extends AnalysisTest {
test("greatest/least cast") {
for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) {
- ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
+ ruleTest(TypeCoercion.FunctionArgumentConversion,
operator(Literal(1.0)
:: Literal(1)
:: Literal.create(1.0, FloatType)
@@ -717,7 +717,7 @@ class TypeCoercionSuite extends AnalysisTest {
:: Cast(Literal(1), DoubleType)
:: Cast(Literal.create(1.0, FloatType), DoubleType)
:: Nil))
- ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
+ ruleTest(TypeCoercion.FunctionArgumentConversion,
operator(Literal(1L)
:: Literal(1)
:: Literal(new java.math.BigDecimal("1000000000000000000000"))
@@ -726,7 +726,7 @@ class TypeCoercionSuite extends AnalysisTest {
:: Cast(Literal(1), DecimalType(22, 0))
:: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0))
:: Nil))
- ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
+ ruleTest(TypeCoercion.FunctionArgumentConversion,
operator(Literal(1.0)
:: Literal.create(null, DecimalType(10, 5))
:: Literal(1)
@@ -735,7 +735,7 @@ class TypeCoercionSuite extends AnalysisTest {
:: Literal.create(null, DecimalType(10, 5)).cast(DoubleType)
:: Literal(1).cast(DoubleType)
:: Nil))
- ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
+ ruleTest(TypeCoercion.FunctionArgumentConversion,
operator(Literal.create(null, DecimalType(15, 0))
:: Literal.create(null, DecimalType(10, 5))
:: Literal(1)
@@ -744,7 +744,7 @@ class TypeCoercionSuite extends AnalysisTest {
:: Literal.create(null, DecimalType(10, 5)).cast(DecimalType(20, 5))
:: Literal(1).cast(DecimalType(20, 5))
:: Nil))
- ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
+ ruleTest(TypeCoercion.FunctionArgumentConversion,
operator(Literal.create(2L, LongType)
:: Literal(1)
:: Literal.create(null, DecimalType(10, 5))
@@ -757,25 +757,25 @@ class TypeCoercionSuite extends AnalysisTest {
}
test("nanvl casts") {
- ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
+ ruleTest(TypeCoercion.FunctionArgumentConversion,
NaNvl(Literal.create(1.0f, FloatType), Literal.create(1.0, DoubleType)),
NaNvl(Cast(Literal.create(1.0f, FloatType), DoubleType), Literal.create(1.0, DoubleType)))
- ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
+ ruleTest(TypeCoercion.FunctionArgumentConversion,
NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0f, FloatType)),
NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(1.0f, FloatType), DoubleType)))
- ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
+ ruleTest(TypeCoercion.FunctionArgumentConversion,
NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)),
NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)))
- ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
+ ruleTest(TypeCoercion.FunctionArgumentConversion,
NaNvl(Literal.create(1.0f, FloatType), Literal.create(null, NullType)),
NaNvl(Literal.create(1.0f, FloatType), Cast(Literal.create(null, NullType), FloatType)))
- ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
+ ruleTest(TypeCoercion.FunctionArgumentConversion,
NaNvl(Literal.create(1.0, DoubleType), Literal.create(null, NullType)),
NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(null, NullType), DoubleType)))
}
test("type coercion for If") {
- val rule = TypeCoercion.IfCoercion(conf)
+ val rule = TypeCoercion.IfCoercion
val intLit = Literal(1)
val doubleLit = Literal(1.0)
val trueLit = Literal.create(true, BooleanType)
@@ -823,20 +823,20 @@ class TypeCoercionSuite extends AnalysisTest {
}
test("type coercion for CaseKeyWhen") {
- ruleTest(TypeCoercion.ImplicitTypeCasts(conf),
+ ruleTest(new TypeCoercion.ImplicitTypeCasts(conf),
CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))),
CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a")))
)
- ruleTest(TypeCoercion.CaseWhenCoercion(conf),
+ ruleTest(TypeCoercion.CaseWhenCoercion,
CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))),
CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a")))
)
- ruleTest(TypeCoercion.CaseWhenCoercion(conf),
+ ruleTest(TypeCoercion.CaseWhenCoercion,
CaseWhen(Seq((Literal(true), Literal(1.2))), Literal.create(1, DecimalType(7, 2))),
CaseWhen(Seq((Literal(true), Literal(1.2))),
Cast(Literal.create(1, DecimalType(7, 2)), DoubleType))
)
- ruleTest(TypeCoercion.CaseWhenCoercion(conf),
+ ruleTest(TypeCoercion.CaseWhenCoercion,
CaseWhen(Seq((Literal(true), Literal(100L))), Literal.create(1, DecimalType(7, 2))),
CaseWhen(Seq((Literal(true), Cast(Literal(100L), DecimalType(22, 2)))),
Cast(Literal.create(1, DecimalType(7, 2)), DecimalType(22, 2)))
@@ -1085,7 +1085,7 @@ class TypeCoercionSuite extends AnalysisTest {
private val timeZoneResolver = ResolveTimeZone(new SQLConf)
private def widenSetOperationTypes(plan: LogicalPlan): LogicalPlan = {
- timeZoneResolver(TypeCoercion.WidenSetOperationTypes(conf)(plan))
+ timeZoneResolver(TypeCoercion.WidenSetOperationTypes(plan))
}
test("WidenSetOperationTypes for except and intersect") {
@@ -1256,7 +1256,7 @@ class TypeCoercionSuite extends AnalysisTest {
test("SPARK-15776 Divide expression's dataType should be casted to Double or Decimal " +
"in aggregation function like sum") {
- val rules = Seq(FunctionArgumentConversion(conf), Division)
+ val rules = Seq(FunctionArgumentConversion, Division)
// Casts Integer to Double
ruleTest(rules, sum(Divide(4, 3)), sum(Divide(Cast(4, DoubleType), Cast(3, DoubleType))))
// Left expression is Double, right expression is Int. Another rule ImplicitTypeCasts will
@@ -1275,7 +1275,7 @@ class TypeCoercionSuite extends AnalysisTest {
}
test("SPARK-17117 null type coercion in divide") {
- val rules = Seq(FunctionArgumentConversion(conf), Division, ImplicitTypeCasts(conf))
+ val rules = Seq(FunctionArgumentConversion, Division, new ImplicitTypeCasts(conf))
val nullLit = Literal.create(null, NullType)
ruleTest(rules, Divide(1L, nullLit), Divide(Cast(1L, DoubleType), Cast(nullLit, DoubleType)))
ruleTest(rules, Divide(nullLit, 1L), Divide(Cast(nullLit, DoubleType), Cast(1L, DoubleType)))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index a2851d071c7c6..3fc0b08c56e02 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -56,6 +57,28 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(MapValues(m2), null)
}
+ test("MapEntries") {
+ def r(values: Any*): InternalRow = create_row(values: _*)
+
+ // Primitive-type keys/values
+ val mi0 = Literal.create(Map(1 -> 1, 2 -> null, 3 -> 2), MapType(IntegerType, IntegerType))
+ val mi1 = Literal.create(Map[Int, Int](), MapType(IntegerType, IntegerType))
+ val mi2 = Literal.create(null, MapType(IntegerType, IntegerType))
+
+ checkEvaluation(MapEntries(mi0), Seq(r(1, 1), r(2, null), r(3, 2)))
+ checkEvaluation(MapEntries(mi1), Seq.empty)
+ checkEvaluation(MapEntries(mi2), null)
+
+ // Non-primitive-type keys/values
+ val ms0 = Literal.create(Map("a" -> "c", "b" -> null), MapType(StringType, StringType))
+ val ms1 = Literal.create(Map[String, String](), MapType(StringType, StringType))
+ val ms2 = Literal.create(null, MapType(StringType, StringType))
+
+ checkEvaluation(MapEntries(ms0), Seq(r("a", "c"), r("b", null)))
+ checkEvaluation(MapEntries(ms1), Seq.empty)
+ checkEvaluation(MapEntries(ms2), null)
+ }
+
test("Sort Array") {
val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))
@@ -134,6 +157,99 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(ArrayContains(a3, Literal("")), null)
checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null)
+
+ // binary
+ val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)),
+ ArrayType(BinaryType))
+ val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](4, 3)),
+ ArrayType(BinaryType))
+ val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null),
+ ArrayType(BinaryType))
+ val b3 = Literal.create(Seq[Array[Byte]](null, Array[Byte](1, 2)),
+ ArrayType(BinaryType))
+ val be = Literal.create(Array[Byte](1, 2), BinaryType)
+ val nullBinary = Literal.create(null, BinaryType)
+
+ checkEvaluation(ArrayContains(b0, be), true)
+ checkEvaluation(ArrayContains(b1, be), false)
+ checkEvaluation(ArrayContains(b0, nullBinary), null)
+ checkEvaluation(ArrayContains(b2, be), null)
+ checkEvaluation(ArrayContains(b3, be), true)
+
+ // complex data types
+ val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)),
+ ArrayType(ArrayType(IntegerType)))
+ val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)),
+ ArrayType(ArrayType(IntegerType)))
+ val aae = Literal.create(Seq[Int](1, 2), ArrayType(IntegerType))
+ checkEvaluation(ArrayContains(aa0, aae), true)
+ checkEvaluation(ArrayContains(aa1, aae), false)
+ }
+
+ test("ArraysOverlap") {
+ val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
+ val a1 = Literal.create(Seq(4, 5, 3), ArrayType(IntegerType))
+ val a2 = Literal.create(Seq(null, 5, 6), ArrayType(IntegerType))
+ val a3 = Literal.create(Seq(7, 8), ArrayType(IntegerType))
+ val a4 = Literal.create(Seq[String](null, ""), ArrayType(StringType))
+ val a5 = Literal.create(Seq[String]("", "abc"), ArrayType(StringType))
+ val a6 = Literal.create(Seq[String]("def", "ghi"), ArrayType(StringType))
+
+ val emptyIntArray = Literal.create(Seq.empty[Int], ArrayType(IntegerType))
+
+ checkEvaluation(ArraysOverlap(a0, a1), true)
+ checkEvaluation(ArraysOverlap(a0, a2), null)
+ checkEvaluation(ArraysOverlap(a1, a2), true)
+ checkEvaluation(ArraysOverlap(a1, a3), false)
+ checkEvaluation(ArraysOverlap(a0, emptyIntArray), false)
+ checkEvaluation(ArraysOverlap(a2, emptyIntArray), false)
+ checkEvaluation(ArraysOverlap(emptyIntArray, a2), false)
+
+ checkEvaluation(ArraysOverlap(a4, a5), true)
+ checkEvaluation(ArraysOverlap(a4, a6), null)
+ checkEvaluation(ArraysOverlap(a5, a6), false)
+
+ // null handling
+ checkEvaluation(ArraysOverlap(emptyIntArray, a2), false)
+ checkEvaluation(ArraysOverlap(
+ emptyIntArray, Literal.create(Seq(null), ArrayType(IntegerType))), false)
+ checkEvaluation(ArraysOverlap(Literal.create(null, ArrayType(IntegerType)), a0), null)
+ checkEvaluation(ArraysOverlap(a0, Literal.create(null, ArrayType(IntegerType))), null)
+ checkEvaluation(ArraysOverlap(
+ Literal.create(Seq(null), ArrayType(IntegerType)),
+ Literal.create(Seq(null), ArrayType(IntegerType))), null)
+
+ // arrays of binaries
+ val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4)),
+ ArrayType(BinaryType))
+ val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)),
+ ArrayType(BinaryType))
+ val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](4, 3)),
+ ArrayType(BinaryType))
+
+ checkEvaluation(ArraysOverlap(b0, b1), true)
+ checkEvaluation(ArraysOverlap(b0, b2), false)
+
+ // arrays of complex data types
+ val aa0 = Literal.create(Seq[Array[String]](Array[String]("a", "b"), Array[String]("c", "d")),
+ ArrayType(ArrayType(StringType)))
+ val aa1 = Literal.create(Seq[Array[String]](Array[String]("e", "f"), Array[String]("a", "b")),
+ ArrayType(ArrayType(StringType)))
+ val aa2 = Literal.create(Seq[Array[String]](Array[String]("b", "a"), Array[String]("f", "g")),
+ ArrayType(ArrayType(StringType)))
+
+ checkEvaluation(ArraysOverlap(aa0, aa1), true)
+ checkEvaluation(ArraysOverlap(aa0, aa2), false)
+
+ // null handling with complex datatypes
+ val emptyBinaryArray = Literal.create(Seq.empty[Array[Byte]], ArrayType(BinaryType))
+ val arrayWithBinaryNull = Literal.create(Seq(null), ArrayType(BinaryType))
+ checkEvaluation(ArraysOverlap(emptyBinaryArray, b0), false)
+ checkEvaluation(ArraysOverlap(b0, emptyBinaryArray), false)
+ checkEvaluation(ArraysOverlap(emptyBinaryArray, arrayWithBinaryNull), false)
+ checkEvaluation(ArraysOverlap(arrayWithBinaryNull, emptyBinaryArray), false)
+ checkEvaluation(ArraysOverlap(arrayWithBinaryNull, b0), null)
+ checkEvaluation(ArraysOverlap(b0, arrayWithBinaryNull), null)
}
test("Slice") {
@@ -283,6 +399,14 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(ArrayPosition(a3, Literal("")), null)
checkEvaluation(ArrayPosition(a3, Literal.create(null, StringType)), null)
+
+ val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)),
+ ArrayType(ArrayType(IntegerType)))
+ val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)),
+ ArrayType(ArrayType(IntegerType)))
+ val aae = Literal.create(Seq[Int](1, 2), ArrayType(IntegerType))
+ checkEvaluation(ArrayPosition(aa0, aae), 1L)
+ checkEvaluation(ArrayPosition(aa1, aae), 0L)
}
test("elementAt") {
@@ -320,7 +444,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType))
val m2 = Literal.create(null, MapType(StringType, StringType))
- checkEvaluation(ElementAt(m0, Literal(1.0)), null)
+ assert(ElementAt(m0, Literal(1.0)).checkInputDataTypes().isFailure)
checkEvaluation(ElementAt(m0, Literal("d")), null)
@@ -331,6 +455,18 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(ElementAt(m0, Literal("c")), null)
checkEvaluation(ElementAt(m2, Literal("a")), null)
+
+ // test binary type as keys
+ val mb0 = Literal.create(
+ Map(Array[Byte](1, 2) -> "1", Array[Byte](3, 4) -> null, Array[Byte](2, 1) -> "2"),
+ MapType(BinaryType, StringType))
+ val mb1 = Literal.create(Map[Array[Byte], String](), MapType(BinaryType, StringType))
+
+ checkEvaluation(ElementAt(mb0, Literal(Array[Byte](1, 2, 3))), null)
+
+ checkEvaluation(ElementAt(mb1, Literal(Array[Byte](1, 2))), null)
+ checkEvaluation(ElementAt(mb0, Literal(Array[Byte](2, 1), BinaryType)), "2")
+ checkEvaluation(ElementAt(mb0, Literal(Array[Byte](3, 4))), null)
}
test("Concat") {
@@ -468,4 +604,22 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(Flatten(asa3), null)
checkEvaluation(Flatten(asa4), null)
}
+
+ test("ArrayRepeat") {
+ val intArray = Literal.create(Seq(1, 2), ArrayType(IntegerType))
+ val strArray = Literal.create(Seq("hi", "hola"), ArrayType(StringType))
+
+ checkEvaluation(ArrayRepeat(Literal("hi"), Literal(0)), Seq())
+ checkEvaluation(ArrayRepeat(Literal("hi"), Literal(-1)), Seq())
+ checkEvaluation(ArrayRepeat(Literal("hi"), Literal(1)), Seq("hi"))
+ checkEvaluation(ArrayRepeat(Literal("hi"), Literal(2)), Seq("hi", "hi"))
+ checkEvaluation(ArrayRepeat(Literal(true), Literal(2)), Seq(true, true))
+ checkEvaluation(ArrayRepeat(Literal(1), Literal(2)), Seq(1, 1))
+ checkEvaluation(ArrayRepeat(Literal(3.2), Literal(2)), Seq(3.2, 3.2))
+ checkEvaluation(ArrayRepeat(Literal(null), Literal(2)), Seq[String](null, null))
+ checkEvaluation(ArrayRepeat(Literal(null, IntegerType), Literal(2)), Seq[Integer](null, null))
+ checkEvaluation(ArrayRepeat(intArray, Literal(2)), Seq(Seq(1, 2), Seq(1, 2)))
+ checkEvaluation(ArrayRepeat(strArray, Literal(2)), Seq(Seq("hi", "hola"), Seq("hi", "hola")))
+ checkEvaluation(ArrayRepeat(Literal("hi"), Literal(null, IntegerType)), null)
+ }
}
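The new expression tests above (MapEntries, ArraysOverlap, ArrayRepeat, and binary map keys for ElementAt) back SQL-level collection functions. As a hedged illustration only, a minimal spark-shell style sketch of how they surface through SQL, assuming a running SparkSession named `spark`:

```scala
// Minimal sketch (not part of the patch): exercising the collection functions whose
// expressions are tested above, assuming a SparkSession named `spark`.

// map_entries turns a map into an array of key/value structs.
spark.sql("SELECT map_entries(map(1, 'a', 2, 'b'))").show(false)

// arrays_overlap is true when the arrays share a non-null element; it returns null
// when the only possible match would have to go through a null element.
spark.sql("SELECT arrays_overlap(array(1, 2, 3), array(3, 4))").show()

// array_repeat builds an array by repeating the element the given number of times.
spark.sql("SELECT array_repeat('hi', 3)").show(false)
```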
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index a22e9d4655e8c..c2a44e0d33b18 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -98,6 +98,9 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
if (expected.isNaN) result.isNaN else expected == result
case (result: Float, expected: Float) =>
if (expected.isNaN) result.isNaN else expected == result
+ case (result: UnsafeRow, expected: GenericInternalRow) =>
+ val structType = exprDataType.asInstanceOf[StructType]
+ result.toSeq(structType) == expected.toSeq(structType)
case (result: Row, expected: InternalRow) => result.toSeq == expected.toSeq(result.schema)
case _ =>
result == expected
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala
index 64b65e2070ed6..7c7c4cccee253 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{DataType, IntegerType}
/**
@@ -45,7 +46,7 @@ case class BadCodegenExpression() extends LeafExpression {
override def eval(input: InternalRow): Any = 10
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
ev.copy(code =
- s"""
+ code"""
|int some_variable = 11;
|int ${ev.value} = 10;
""".stripMargin)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala
new file mode 100644
index 0000000000000..d2c6420eadb20
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.types.{BooleanType, IntegerType}
+
+class CodeBlockSuite extends SparkFunSuite {
+
+ test("Block interpolates string and ExprValue inputs") {
+ val isNull = JavaCode.isNullVariable("expr1_isNull")
+ val stringLiteral = "false"
+ val code = code"boolean $isNull = $stringLiteral;"
+ assert(code.toString == "boolean expr1_isNull = false;")
+ }
+
+ test("Literals are folded into string code parts instead of block inputs") {
+ val value = JavaCode.variable("expr1", IntegerType)
+ val intLiteral = 1
+ val code = code"int $value = $intLiteral;"
+ assert(code.asInstanceOf[CodeBlock].blockInputs === Seq(value))
+ }
+
+ test("Block.stripMargin") {
+ val isNull = JavaCode.isNullVariable("expr1_isNull")
+ val value = JavaCode.variable("expr1", IntegerType)
+ val code1 =
+ code"""
+ |boolean $isNull = false;
+ |int $value = ${JavaCode.defaultLiteral(IntegerType)};""".stripMargin
+ val expected =
+ s"""
+ |boolean expr1_isNull = false;
+ |int expr1 = ${JavaCode.defaultLiteral(IntegerType)};""".stripMargin.trim
+ assert(code1.toString == expected)
+
+ val code2 =
+ code"""
+ >boolean $isNull = false;
+ >int $value = ${JavaCode.defaultLiteral(IntegerType)};""".stripMargin('>')
+ assert(code2.toString == expected)
+ }
+
+ test("Block can capture input expr values") {
+ val isNull = JavaCode.isNullVariable("expr1_isNull")
+ val value = JavaCode.variable("expr1", IntegerType)
+ val code =
+ code"""
+ |boolean $isNull = false;
+ |int $value = -1;
+ """.stripMargin
+ val exprValues = code.exprValues
+ assert(exprValues.size == 2)
+ assert(exprValues === Set(value, isNull))
+ }
+
+ test("concatenate blocks") {
+ val isNull1 = JavaCode.isNullVariable("expr1_isNull")
+ val value1 = JavaCode.variable("expr1", IntegerType)
+ val isNull2 = JavaCode.isNullVariable("expr2_isNull")
+ val value2 = JavaCode.variable("expr2", IntegerType)
+ val literal = JavaCode.literal("100", IntegerType)
+
+ val code =
+ code"""
+ |boolean $isNull1 = false;
+ |int $value1 = -1;""".stripMargin +
+ code"""
+ |boolean $isNull2 = true;
+ |int $value2 = $literal;""".stripMargin
+
+ val expected =
+ """
+ |boolean expr1_isNull = false;
+ |int expr1 = -1;
+ |boolean expr2_isNull = true;
+ |int expr2 = 100;""".stripMargin.trim
+
+ assert(code.toString == expected)
+
+ val exprValues = code.exprValues
+ assert(exprValues.size == 5)
+ assert(exprValues === Set(isNull1, value1, isNull2, value2, literal))
+ }
+
+ test("Throws exception when interpolating unexcepted object in code block") {
+ val obj = Tuple2(1, 1)
+ val e = intercept[IllegalArgumentException] {
+ code"$obj"
+ }
+ assert(e.getMessage().contains(s"Can not interpolate ${obj.getClass.getName}"))
+ }
+
+ test("replace expr values in code block") {
+ val expr = JavaCode.expression("1 + 1", IntegerType)
+ val isNull = JavaCode.isNullVariable("expr1_isNull")
+ val exprInFunc = JavaCode.variable("expr1", IntegerType)
+
+ val code =
+ code"""
+ |callFunc(int $expr) {
+ | boolean $isNull = false;
+ | int $exprInFunc = $expr + 1;
+ |}""".stripMargin
+
+ val aliasedParam = JavaCode.variable("aliased", expr.javaType)
+ val aliasedInputs = code.asInstanceOf[CodeBlock].blockInputs.map {
+ case _: SimpleExprValue => aliasedParam
+ case other => other
+ }
+ val aliasedCode = CodeBlock(code.asInstanceOf[CodeBlock].codeParts, aliasedInputs).stripMargin
+ val expected =
+ code"""
+ |callFunc(int $aliasedParam) {
+ | boolean $isNull = false;
+ | int $exprInFunc = $aliasedParam + 1;
+ |}""".stripMargin
+ assert(aliasedCode.toString == expected.toString)
+ }
+}
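The suite above pins down the behaviour of the `code"..."` interpolator (folding literals, tracking `ExprValue` inputs, stripMargin, concatenation, replacement). For orientation, a hedged sketch of how a leaf expression typically uses it in `doGenCode`, mirroring the production changes later in this patch; the class name `ConstantTen` is hypothetical:

```scala
// Hedged sketch only: a hypothetical leaf expression emitting its code through the
// code"..." interpolator, in the same style as the doGenCode changes in this patch.
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.LeafExpression
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{DataType, IntegerType}

case class ConstantTen() extends LeafExpression {
  override def nullable: Boolean = false
  override def dataType: DataType = IntegerType
  override def eval(input: InternalRow): Any = 10

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    // code"..." keeps ev.value as a tracked block input rather than a plain string,
    // so the resulting block's exprValues can be inspected or rewritten later.
    ev.copy(code = code"int ${ev.value} = 10;")
  }
}
```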
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
index 633d86d495581..5452e72b38647 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
@@ -439,4 +439,17 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
.select('c as 'sCol2, 'a as 'sCol1)
checkRule(originalQuery, correctAnswer)
}
+
+ test("SPARK-24313: support binary type as map keys in GetMapValue") {
+ val mb0 = Literal.create(
+ Map(Array[Byte](1, 2) -> "1", Array[Byte](3, 4) -> null, Array[Byte](2, 1) -> "2"),
+ MapType(BinaryType, StringType))
+ val mb1 = Literal.create(Map[Array[Byte], String](), MapType(BinaryType, StringType))
+
+ checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](1, 2, 3))), null)
+
+ checkEvaluation(GetMapValue(mb1, Literal(Array[Byte](1, 2))), null)
+ checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](2, 1), BinaryType)), "2")
+ checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](3, 4))), null)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index 812bfdd7bb885..fb51376c6163f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -318,6 +318,16 @@ class PlanParserSuite extends AnalysisTest {
assertEqual(
"select * from t lateral view posexplode(x) posexpl as x, y",
expected)
+
+ intercept(
+ """select *
+ |from t
+ |lateral view explode(x) expl
+ |pivot (
+ | sum(x)
+ | FOR y IN ('a', 'b')
+ |)""".stripMargin,
+ "LATERAL cannot be used together with PIVOT in FROM clause")
}
test("joins") {
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index ef41837f89d68..f270c70fbfcf0 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -38,7 +38,7 @@
com.univocity
univocity-parsers
- 2.5.9
+ 2.6.3
jar
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java
index 10d6ed85a4080..daedfd7e78f5f 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java
@@ -18,7 +18,6 @@
package org.apache.spark.sql.execution.datasources.parquet;
-import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java
index aacefacfc1c1a..c62dc3d86386e 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java
@@ -26,7 +26,6 @@
import org.apache.parquet.column.values.ValuesReader;
import org.apache.parquet.io.api.Binary;
-import org.apache.spark.unsafe.Platform;
/**
* An implementation of the Parquet PLAIN decoder that supports the vectorized interface.
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java
index 0ea4dc6b5def3..b2526ded53d92 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java
@@ -30,7 +30,7 @@ public interface ReadSupport extends DataSourceV2 {
/**
* Creates a {@link DataSourceReader} to scan the data from this data source.
*
- * If this method fails (by throwing an exception), the action would fail and no Spark job was
+ * If this method fails (by throwing an exception), the action will fail and no Spark job will be
* submitted.
*
* @param options the options for the returned data source reader, which is an immutable
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java
index 3801402268af1..f31659904cc53 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java
@@ -35,7 +35,7 @@ public interface ReadSupportWithSchema extends DataSourceV2 {
/**
* Create a {@link DataSourceReader} to scan the data from this data source.
*
- * If this method fails (by throwing an exception), the action would fail and no Spark job was
+ * If this method fails (by throwing an exception), the action will fail and no Spark job will be
* submitted.
*
* @param schema the full schema of this data source reader. Full schema usually maps to the
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java
index cab56453816cc..83aeec0c47853 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java
@@ -35,7 +35,7 @@ public interface WriteSupport extends DataSourceV2 {
* Creates an optional {@link DataSourceWriter} to save the data to this data source. Data
* sources can return None if there is no writing needed to be done according to the save mode.
*
- * If this method fails (by throwing an exception), the action would fail and no Spark job was
+ * If this method fails (by throwing an exception), the action will fail and no Spark job will be
* submitted.
*
* @param jobId A unique string for the writing job. It's possible that there are many writing
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java
index c24f3b21eade1..dcb87715d0b6f 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java
@@ -27,9 +27,9 @@
@InterfaceStability.Evolving
public interface ContinuousInputPartition<T> extends InputPartition<T> {
/**
- * Create a DataReader with particular offset as its startOffset.
+ * Create an input partition reader with a particular offset as its startOffset.
*
- * @param offset offset want to set as the DataReader's startOffset.
+ * @param offset the offset to set as the input partition reader's startOffset.
*/
InputPartitionReader<T> createContinuousReader(PartitionOffset offset);
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java
index f898c296e4245..36a3e542b5a11 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java
@@ -31,7 +31,7 @@
* {@link ReadSupport#createReader(DataSourceOptions)} or
* {@link ReadSupportWithSchema#createReader(StructType, DataSourceOptions)}.
* It can mix in various query optimization interfaces to speed up the data scan. The actual scan
- * logic is delegated to {@link InputPartition}s that are returned by
+ * logic is delegated to {@link InputPartition}s, which are returned by
* {@link #planInputPartitions()}.
*
* There are mainly 3 kinds of query optimizations:
@@ -45,8 +45,8 @@
* only one of them would be respected, according to the priority list from high to low:
* {@link SupportsScanColumnarBatch}, {@link SupportsScanUnsafeRow}.
*
- * If an exception was throw when applying any of these query optimizations, the action would fail
- * and no Spark job was submitted.
+ * If an exception is thrown when applying any of these query optimizations, the action will fail
+ * and no Spark job will be submitted.
*
* Spark first applies all operator push-down optimizations that this data source supports. Then
* Spark collects information this data source reported for further optimizations. Finally Spark
@@ -59,21 +59,21 @@ public interface DataSourceReader {
* Returns the actual schema of this data source reader, which may be different from the physical
* schema of the underlying storage, as column pruning or other optimizations may happen.
*
- * If this method fails (by throwing an exception), the action would fail and no Spark job was
+ * If this method fails (by throwing an exception), the action will fail and no Spark job will be
* submitted.
*/
StructType readSchema();
/**
- * Returns a list of read tasks. Each task is responsible for creating a data reader to
- * output data for one RDD partition. That means the number of tasks returned here is same as
- * the number of RDD partitions this scan outputs.
+ * Returns a list of {@link InputPartition}s. Each {@link InputPartition} is responsible for
+ * creating a data reader to output data of one RDD partition. The number of input partitions
+ * returned here is the same as the number of RDD partitions this scan outputs.
*
* Note that, this may not be a full scan if the data source reader mixes in other optimization
* interfaces like column pruning, filter push-down, etc. These optimizations are applied before
* Spark issues the scan request.
*
- * If this method fails (by throwing an exception), the action would fail and no Spark job was
+ * If this method fails (by throwing an exception), the action will fail and no Spark job will be
* submitted.
*/
List<InputPartition<Row>> planInputPartitions();
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java
index c581e3b5d0047..f2038d0de3ffe 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java
@@ -23,28 +23,29 @@
/**
* An input partition returned by {@link DataSourceReader#planInputPartitions()} and is
- * responsible for creating the actual data reader. The relationship between
- * {@link InputPartition} and {@link InputPartitionReader}
+ * responsible for creating the actual data reader of one RDD partition.
+ * The relationship between {@link InputPartition} and {@link InputPartitionReader}
* is similar to the relationship between {@link Iterable} and {@link java.util.Iterator}.
*
- * Note that input partitions will be serialized and sent to executors, then the partition reader
- * will be created on executors and do the actual reading. So {@link InputPartition} must be
- * serializable and {@link InputPartitionReader} doesn't need to be.
+ * Note that {@link InputPartition}s will be serialized and sent to executors, then
+ * {@link InputPartitionReader}s will be created on executors to do the actual reading. So
+ * {@link InputPartition} must be serializable while {@link InputPartitionReader} doesn't need to
+ * be.
*/
@InterfaceStability.Evolving
public interface InputPartition<T> extends Serializable {
/**
- * The preferred locations where the data reader returned by this partition can run faster,
- * but Spark does not guarantee to run the data reader on these locations.
+ * The preferred locations where the input partition reader returned by this partition can run
+ * faster, but Spark does not guarantee that the input partition reader will run on these locations.
* The implementations should make sure that it can be run on any location.
* The location is a string representing the host name.
*
* Note that if a host name cannot be recognized by Spark, it will be ignored as it was not in
- * the returned locations. By default this method returns empty string array, which means this
- * task has no location preference.
+ * the returned locations. The default return value is an empty string array, which means this
+ * input partition's reader has no location preference.
*
- * If this method fails (by throwing an exception), the action would fail and no Spark job was
+ * If this method fails (by throwing an exception), the action will fail and no Spark job will be
* submitted.
*/
default String[] preferredLocations() {
@@ -52,7 +53,7 @@ default String[] preferredLocations() {
}
/**
- * Returns a data reader to do the actual reading work.
+ * Returns an input partition reader to do the actual reading work.
*
* If this method fails (by throwing an exception), the corresponding Spark task would fail and
* get retried until hitting the maximum retry times.
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java
index 1b7051f1ad0af..33fa7be4c1b20 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java
@@ -23,12 +23,12 @@
import org.apache.spark.annotation.InterfaceStability;
/**
- * A data reader returned by {@link InputPartition#createPartitionReader()} and is responsible for
- * outputting data for a RDD partition.
+ * An input partition reader returned by {@link InputPartition#createPartitionReader()} and is
+ * responsible for outputting data for a RDD partition.
*
- * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data
- * source readers, or {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow} for data source
- * readers that mix in {@link SupportsScanUnsafeRow}.
+ * Note that currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal input
+ * partition readers, or {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow} for input
+ * partition readers that mix in {@link SupportsScanUnsafeRow}.
*/
@InterfaceStability.Evolving
public interface InputPartitionReader<T> extends Closeable {
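Taken together, the renamed javadocs describe the v2 read path: the driver-side `DataSourceReader` plans a list of `InputPartition`s, each partition is serialized to an executor, and its `InputPartitionReader` does the row-by-row reading. A hedged Scala sketch of that shape; the class names are hypothetical and the Row-typed `planInputPartitions` signature is assumed from the surrounding docs:

```scala
// Hedged sketch of the relationship described above; not part of the patch.
// Class names are hypothetical, and the Row-typed planInputPartitions signature
// is assumed from the surrounding javadoc (UnsafeRow requires SupportsScanUnsafeRow).
import java.util.{ArrayList => JArrayList, List => JList}

import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition, InputPartitionReader}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

class TinyReader extends DataSourceReader {
  override def readSchema(): StructType = StructType(Seq(StructField("i", IntegerType)))

  // Driver side: one InputPartition per output RDD partition.
  override def planInputPartitions(): JList[InputPartition[Row]] = {
    val partitions = new JArrayList[InputPartition[Row]]()
    partitions.add(new TinyPartition(0 until 5))
    partitions.add(new TinyPartition(5 until 10))
    partitions
  }
}

// Serialized and shipped to executors, where it opens the actual reader.
class TinyPartition(range: Range) extends InputPartition[Row] {
  override def createPartitionReader(): InputPartitionReader[Row] =
    new InputPartitionReader[Row] {
      private val iter = range.iterator
      override def next(): Boolean = iter.hasNext
      override def get(): Row = Row(iter.next())
      override def close(): Unit = ()
    }
}
```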
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java
index d2ee9518d628f..5e32ba6952e1c 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java
@@ -22,7 +22,8 @@
/**
* An interface to represent data distribution requirement, which specifies how the records should
- * be distributed among the data partitions(one {@link InputPartitionReader} outputs data for one partition).
+ * be distributed among the data partitions (one {@link InputPartitionReader} outputs data for one
+ * partition).
* Note that this interface has nothing to do with the data ordering inside one
* partition(the output records of a single {@link InputPartitionReader}).
*
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java
index 716c5c0e9e15a..6e960bedf8020 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java
@@ -35,8 +35,8 @@
@InterfaceStability.Evolving
public interface ContinuousReader extends BaseStreamingSource, DataSourceReader {
/**
- * Merge partitioned offsets coming from {@link ContinuousInputPartitionReader} instances for each
- * partition to a single global offset.
+ * Merge partitioned offsets coming from {@link ContinuousInputPartitionReader} instances
+ * for each partition to a single global offset.
*/
Offset mergeOffsets(PartitionOffset[] offsets);
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java
index 0a0fd8db58035..0030a9f05dba7 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java
@@ -34,8 +34,8 @@
* It can mix in various writing optimization interfaces to speed up the data saving. The actual
* writing logic is delegated to {@link DataWriter}.
*
- * If an exception was throw when applying any of these writing optimizations, the action would fail
- * and no Spark job was submitted.
+ * If an exception is thrown when applying any of these writing optimizations, the action will fail
+ * and no Spark job will be submitted.
*
* The writing procedure is:
* 1. Create a writer factory by {@link #createWriterFactory()}, serialize and send it to all the
@@ -58,7 +58,7 @@ public interface DataSourceWriter {
/**
* Creates a writer factory which will be serialized and sent to executors.
*
- * If this method fails (by throwing an exception), the action would fail and no Spark job was
+ * If this method fails (by throwing an exception), the action will fail and no Spark job will be
* submitted.
*/
DataWriterFactory<Row> createWriterFactory();
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java
index c2c2ab73257e8..7527bcc0c4027 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java
@@ -35,7 +35,7 @@ public interface DataWriterFactory<T> extends Serializable {
/**
* Returns a data writer to do the actual writing work.
*
- * If this method fails (by throwing an exception), the action would fail and no Spark job was
+ * If this method fails (by throwing an exception), the action will fail and no Spark job will be
* submitted.
*
* @param partitionId A unique id of the RDD partition that the returned writer will process.
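On the write side, the same lifecycle applies: the driver creates the `DataWriterFactory`, which is serialized to executors and asked for one `DataWriter` per partition attempt. A hedged sketch of a `DataWriter` only (class names and the in-memory buffering are hypothetical; the factory's exact parameters are left out here):

```scala
// Hedged sketch, not part of the patch: a DataWriter created on the executor by a
// DataWriterFactory. Class names and the in-memory buffering are hypothetical.
import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriterCommitMessage}

// Commit messages are sent back to the driver and passed to DataSourceWriter.commit().
case class TinyCommitMessage(rowCount: Long) extends WriterCommitMessage

class TinyWriter extends DataWriter[Row] {
  private val buffer = scala.collection.mutable.ArrayBuffer.empty[Row]

  override def write(record: Row): Unit = buffer += record

  override def commit(): WriterCommitMessage = TinyCommitMessage(buffer.size.toLong)

  // Called when the task fails; should clean up anything partially written.
  override def abort(): Unit = buffer.clear()
}
```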
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 53f44888ebaff..917f0cb221412 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -257,7 +257,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* @param connectionProperties JDBC database connection arguments, a list of arbitrary string
* tag/value. Normally at least a "user" and "password" property
* should be included. "fetchsize" can be used to control the
- * number of rows per fetch.
+ * number of rows per fetch and "queryTimeout" can be used to wait
+ * up to the given number of seconds for a Statement object to execute.
* @since 1.4.0
*/
def jdbc(
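A hedged usage sketch of the documented properties (the JDBC URL, table, and credentials are placeholders; a SparkSession `spark` is assumed):

```scala
// Hedged sketch: passing "fetchsize" and "queryTimeout" via connectionProperties.
// URL, table name and credentials below are placeholders.
import java.util.Properties

val props = new Properties()
props.setProperty("user", "sa")
props.setProperty("password", "secret")
props.setProperty("fetchsize", "1000")   // rows fetched per round trip
props.setProperty("queryTimeout", "30")  // seconds to wait for a Statement to execute

val events = spark.read.jdbc("jdbc:postgresql://db-host:5432/mydb", "public.events", props)
```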
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index d518e07bfb62c..32267eb0300f5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -511,6 +511,16 @@ class Dataset[T] private[sql](
*/
def isLocal: Boolean = logicalPlan.isInstanceOf[LocalRelation]
+ /**
+ * Returns true if the `Dataset` is empty.
+ *
+ * @group basic
+ * @since 2.4.0
+ */
+ def isEmpty: Boolean = withAction("isEmpty", limit(1).groupBy().count().queryExecution) { plan =>
+ plan.executeCollect().head.getLong(0) == 0
+ }
+
/**
* Returns true if this Dataset contains one or more sources that continuously
* return data as it arrives. A Dataset that reads data from a streaming source
@@ -1607,7 +1617,9 @@ class Dataset[T] private[sql](
*/
@Experimental
@InterfaceStability.Evolving
- def reduce(func: (T, T) => T): T = rdd.reduce(func)
+ def reduce(func: (T, T) => T): T = withNewRDDExecutionId {
+ rdd.reduce(func)
+ }
/**
* :: Experimental ::
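Two behavioural notes on the Dataset changes above, with a hedged usage sketch assuming a SparkSession `spark`: `isEmpty` only runs `limit(1).groupBy().count()` rather than collecting rows, and `reduce` is now wrapped in an RDD execution id so it is tracked like other actions.

```scala
// Hedged usage sketch for the Dataset changes above; assumes a SparkSession `spark`.
import spark.implicits._

spark.emptyDataset[Int].isEmpty   // true, computed via limit(1).groupBy().count()
spark.range(10).isEmpty           // false

// reduce now runs under withNewRDDExecutionId, so it appears in the SQL UI like other actions.
Seq(1, 2, 3).toDS().reduce(_ + _) // 6
```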
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
index fc3dbc1c5591b..48abad9078650 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
@@ -58,14 +59,14 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
}
val valueVar = ctx.freshName("value")
val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]"
- val code = s"${ctx.registerComment(str)}\n" + (if (nullable) {
- s"""
+ val code = code"${ctx.registerComment(str)}" + (if (nullable) {
+ code"""
boolean $isNullVar = $columnVar.isNullAt($ordinal);
$javaType $valueVar = $isNullVar ? ${CodeGenerator.defaultValue(dataType)} : ($value);
"""
} else {
- s"$javaType $valueVar = $value;"
- }).trim
+ code"$javaType $valueVar = $value;"
+ })
ExprCode(code, isNullVar, JavaCode.variable(valueVar, dataType))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index 08ff33afbba3d..61c14fee09337 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -69,7 +69,7 @@ trait DataSourceScanExec extends LeafExecNode with CodegenSupport {
* Shorthand for calling redactString() without specifying redacting rules
*/
private def redact(text: String): String = {
- Utils.redact(sqlContext.sessionState.conf.stringRedationPattern, text)
+ Utils.redact(sqlContext.sessionState.conf.stringRedactionPattern, text)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
index e4812f3d338fb..5b4edf5136e3f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -152,7 +153,7 @@ case class ExpandExec(
} else {
val isNull = ctx.freshName("isNull")
val value = ctx.freshName("value")
- val code = s"""
+ val code = code"""
|boolean $isNull = true;
|${CodeGenerator.javaType(firstExpr.dataType)} $value =
| ${CodeGenerator.defaultValue(firstExpr.dataType)};
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
index f40c50df74ccb..2549b9e1537a0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
@@ -21,6 +21,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types._
@@ -313,13 +314,13 @@ case class GenerateExec(
if (checks.nonEmpty) {
val isNull = ctx.freshName("isNull")
val code =
- s"""
+ code"""
|boolean $isNull = ${checks.mkString(" || ")};
|$javaType $value = $isNull ? ${CodeGenerator.defaultValue(dt)} : $getter;
""".stripMargin
ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, dt))
} else {
- ExprCode(s"$javaType $value = $getter;", FalseLiteral, JavaCode.variable(value, dt))
+ ExprCode(code"$javaType $value = $getter;", FalseLiteral, JavaCode.variable(value, dt))
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 15379a0663f7d..3112b306c365e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -225,7 +225,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
* Redact the sensitive information in the given string.
*/
private def withRedaction(message: String): String = {
- Utils.redact(sparkSession.sessionState.conf.stringRedationPattern, message)
+ Utils.redact(sparkSession.sessionState.conf.stringRedactionPattern, message)
}
/** A special namespace for commands that can be used to debug query execution. */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
index 2c5102b1e5ee7..439932b0cc3ac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -68,16 +68,18 @@ object SQLExecution {
// sparkContext.getCallSite() would first try to pick up any call site that was previously
// set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on
// streaming queries would give us call site like "run at <unknown>:0"
- val callSite = sparkSession.sparkContext.getCallSite()
+ val callSite = sc.getCallSite()
- sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart(
- executionId, callSite.shortForm, callSite.longForm, queryExecution.toString,
- SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis()))
- try {
- body
- } finally {
- sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd(
- executionId, System.currentTimeMillis()))
+ withSQLConfPropagated(sparkSession) {
+ sc.listenerBus.post(SparkListenerSQLExecutionStart(
+ executionId, callSite.shortForm, callSite.longForm, queryExecution.toString,
+ SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis()))
+ try {
+ body
+ } finally {
+ sc.listenerBus.post(SparkListenerSQLExecutionEnd(
+ executionId, System.currentTimeMillis()))
+ }
}
} finally {
executionIdToQueryExecution.remove(executionId)
@@ -90,13 +92,41 @@ object SQLExecution {
* thread from the original one, this method can be used to connect the Spark jobs in this action
* with the known executionId, e.g., `BroadcastExchangeExec.relationFuture`.
*/
- def withExecutionId[T](sc: SparkContext, executionId: String)(body: => T): T = {
+ def withExecutionId[T](sparkSession: SparkSession, executionId: String)(body: => T): T = {
+ val sc = sparkSession.sparkContext
val oldExecutionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+ withSQLConfPropagated(sparkSession) {
+ try {
+ sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId)
+ body
+ } finally {
+ sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId)
+ }
+ }
+ }
+
+ /**
+ * Wrap an action with the specified SQL configs. These configs will be propagated to the executor
+ * side via job local properties.
+ */
+ def withSQLConfPropagated[T](sparkSession: SparkSession)(body: => T): T = {
+ val sc = sparkSession.sparkContext
+ // Set all the specified SQL configs to local properties, so that they can be available at
+ // the executor side.
+ val allConfigs = sparkSession.sessionState.conf.getAllConfs
+ val originalLocalProps = allConfigs.collect {
+ case (key, value) if key.startsWith("spark") =>
+ val originalValue = sc.getLocalProperty(key)
+ sc.setLocalProperty(key, value)
+ (key, originalValue)
+ }
+
try {
- sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId)
body
} finally {
- sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId)
+ for ((key, value) <- originalLocalProps) {
+ sc.setLocalProperty(key, value)
+ }
}
}
}
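`withSQLConfPropagated` copies every session SQL conf whose key starts with "spark" into the SparkContext's local properties for the duration of the wrapped action, restoring the previous values afterwards, so executor-side code can read them from `TaskContext`. A hedged sketch of observing that from a task (the conf key is only illustrative; a SparkSession `spark` is assumed):

```scala
// Hedged sketch, not part of the patch: a UDF that reads a propagated session conf from
// the task's local properties. The conf key is illustrative; any "spark."-prefixed
// session conf should be visible this way while a SQL action runs.
import org.apache.spark.TaskContext
import org.apache.spark.sql.functions.{lit, udf}

val readLocalProp = udf { key: String =>
  Option(TaskContext.get().getLocalProperty(key)).getOrElse("<unset>")
}

spark.conf.set("spark.sql.shuffle.partitions", "7")
spark.range(1)
  .select(readLocalProp(lit("spark.sql.shuffle.partitions")))
  .show(false)   // expected to show "7" once session confs are propagated as above
```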
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 82b4eb9fba242..b97a87a122406 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -66,9 +66,11 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
object SpecialLimits extends Strategy {
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case ReturnAnswer(rootPlan) => rootPlan match {
- case Limit(IntegerLiteral(limit), Sort(order, true, child)) =>
+ case Limit(IntegerLiteral(limit), Sort(order, true, child))
+ if limit < conf.topKSortFallbackThreshold =>
TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil
- case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child))) =>
+ case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child)))
+ if limit < conf.topKSortFallbackThreshold =>
TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil
case Limit(IntegerLiteral(limit), child) =>
// With whole stage codegen, Spark releases resources only when all the output data of the
@@ -79,9 +81,11 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
CollectLimitExec(limit, LocalLimitExec(limit, planLater(child))) :: Nil
case other => planLater(other) :: Nil
}
- case Limit(IntegerLiteral(limit), Sort(order, true, child)) =>
+ case Limit(IntegerLiteral(limit), Sort(order, true, child))
+ if limit < conf.topKSortFallbackThreshold =>
TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil
- case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child))) =>
+ case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child)))
+ if limit < conf.topKSortFallbackThreshold =>
TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil
case _ => Nil
}
@@ -361,7 +365,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case Join(left, right, _, _) if left.isStreaming && right.isStreaming =>
throw new AnalysisException(
- "Stream stream joins without equality predicate is not supported", plan = Some(plan))
+ "Stream-stream join without equality predicate is not supported", plan = Some(plan))
case _ => Nil
}
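The new guard means `ORDER BY ... LIMIT n` is only planned as `TakeOrderedAndProjectExec` when `n` is below the top-K fallback threshold; larger limits fall back to a full sort followed by a limit. A hedged way to observe this (the conf key is assumed to mirror `conf.topKSortFallbackThreshold`; a SparkSession `spark` is assumed):

```scala
// Hedged sketch: comparing plans below and above the top-K fallback threshold.
// The conf key is assumed to match conf.topKSortFallbackThreshold used above.
import org.apache.spark.sql.functions.col

def topK() = spark.range(1000).orderBy(col("id").desc).limit(10)

topK().explain()  // expected: TakeOrderedAndProject, since 10 is below the default threshold

spark.conf.set("spark.sql.execution.topKSortFallbackThreshold", "5")
topK().explain()  // expected: a Sort followed by a limit instead of TakeOrderedAndProject
```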
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 828b51fa199de..372dc3db36ce6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -27,6 +27,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
@@ -122,10 +123,10 @@ trait CodegenSupport extends SparkPlan {
ctx.INPUT_ROW = row
ctx.currentVars = colVars
val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
- val code = s"""
+ val code = code"""
|$evaluateInputs
- |${ev.code.trim}
- """.stripMargin.trim
+ |${ev.code}
+ """.stripMargin
ExprCode(code, FalseLiteral, ev.value)
} else {
// There are no columns
@@ -259,8 +260,8 @@ trait CodegenSupport extends SparkPlan {
* them to be evaluated twice.
*/
protected def evaluateVariables(variables: Seq[ExprCode]): String = {
- val evaluate = variables.filter(_.code != "").map(_.code.trim).mkString("\n")
- variables.foreach(_.code = "")
+ val evaluate = variables.filter(_.code.nonEmpty).map(_.code.toString).mkString("\n")
+ variables.foreach(_.code = EmptyBlock)
evaluate
}
@@ -275,8 +276,8 @@ trait CodegenSupport extends SparkPlan {
val evaluateVars = new StringBuilder
variables.zipWithIndex.foreach { case (ev, i) =>
if (ev.code != "" && required.contains(attributes(i))) {
- evaluateVars.append(ev.code.trim + "\n")
- ev.code = ""
+ evaluateVars.append(ev.code.toString + "\n")
+ ev.code = EmptyBlock
}
}
evaluateVars.toString()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 6a8ec4f722aea..8c7b2c187cccd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
@@ -190,7 +191,7 @@ case class HashAggregateExec(
val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue")
// The initial expression should not access any column
val ev = e.genCode(ctx)
- val initVars = s"""
+ val initVars = code"""
| $isNull = ${ev.isNull};
| $value = ${ev.value};
""".stripMargin
@@ -773,8 +774,8 @@ case class HashAggregateExec(
val findOrInsertRegularHashMap: String =
s"""
|// generate grouping key
- |${unsafeRowKeyCode.code.trim}
- |${hashEval.code.trim}
+ |${unsafeRowKeyCode.code}
+ |${hashEval.code}
|if ($checkFallbackForBytesToBytesMap) {
| // try to get the buffer from hash map
| $unsafeRowBuffer =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala
index de2d630de3fdb..e1c85823259b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, DeclarativeAggregate}
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._
/**
@@ -50,7 +51,7 @@ abstract class HashMapGenerator(
val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue")
val ev = e.genCode(ctx)
val initVars =
- s"""
+ code"""
| $isNull = ${ev.isNull};
| $value = ${ev.value};
""".stripMargin
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
index 22b63513548fe..66888fce7f9f5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
@@ -133,6 +133,14 @@ private[arrow] abstract class ArrowFieldWriter {
valueVector match {
case fixedWidthVector: BaseFixedWidthVector => fixedWidthVector.reset()
case variableWidthVector: BaseVariableWidthVector => variableWidthVector.reset()
+ case listVector: ListVector =>
+ // Manual "reset" the underlying buffer.
+ // TODO: When we upgrade to Arrow 0.10.0, we can simply remove this and call
+ // `listVector.reset()`.
+ val buffers = listVector.getBuffers(false)
+ buffers.foreach(buf => buf.setZero(0, buf.capacity()))
+ listVector.setValueCount(0)
+ listVector.setLastSet(0)
case _ =>
}
count = 0
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 1edfdc888afd8..9434ceb7cd16c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -345,6 +345,20 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
override val output: Seq[Attribute] = range.output
+ override def outputOrdering: Seq[SortOrder] = range.outputOrdering
+
+ override def outputPartitioning: Partitioning = {
+ if (numElements > 0) {
+ if (numSlices == 1) {
+ SinglePartition
+ } else {
+ RangePartitioning(outputOrdering, numSlices)
+ }
+ } else {
+ UnknownPartitioning(0)
+ }
+ }
+
override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
@@ -629,7 +643,7 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode {
Future {
// This will run in another thread. Set the execution id so that we can connect these jobs
// with the correct execution.
- SQLExecution.withExecutionId(sparkContext, executionId) {
+ SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) {
val beforeCollect = System.nanoTime()
// Note that we use .executeCollect() because we don't want to convert data to Scala types
val rows: Array[InternalRow] = child.executeCollect()
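Because `RangeExec` now reports its output ordering and partitioning, the planner can skip redundant sorts and shuffles for operators keyed on `id`. A hedged check (assumes a SparkSession `spark`; whether a Sort or Exchange actually disappears depends on the rest of the plan):

```scala
// Hedged sketch: inspect the physical plan to see whether the Range's reported
// ordering/partitioning lets the planner drop redundant Sort/Exchange nodes.
val r = spark.range(0, 100, 1, numPartitions = 4)

r.orderBy("id").explain()   // the range is already ascending by id, so an extra sort may be avoided
```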
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
index 1edf27619ad7b..f9a24806953e6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
@@ -32,7 +32,6 @@ import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCoercion}
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SchemaUtils
@@ -522,8 +521,6 @@ object PartitioningUtils {
private val findWiderTypeForPartitionColumn: (DataType, DataType) => DataType = {
case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) => StringType
case (DoubleType, LongType) | (LongType, DoubleType) => StringType
- case (t1, t2) =>
- TypeCoercion.findWiderTypeForTwo(
- t1, t2, SQLConf.get.caseSensitiveAnalysis).getOrElse(StringType)
+ case (t1, t2) => TypeCoercion.findWiderTypeForTwo(t1, t2).getOrElse(StringType)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala
index 568e953a5db66..00b1b5dedb593 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala
@@ -17,13 +17,12 @@
package org.apache.spark.sql.execution.datasources
-import org.apache.spark.SparkEnv
import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.command.RunnableCommand
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.CreatableRelationProvider
-import org.apache.spark.util.Utils
/**
* Saves the results of `query` in to a data source.
@@ -50,7 +49,7 @@ case class SaveIntoDataSourceCommand(
}
override def simpleString: String = {
- val redacted = Utils.redact(SparkEnv.get.conf, options.toSeq).toMap
+ val redacted = SQLConf.get.redactOptions(options)
s"SaveIntoDataSourceCommand ${dataSource}, ${redacted}, ${mode}"
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index ed2dc65a47914..dd41aee0f2ebc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -25,6 +25,7 @@ import org.apache.commons.lang3.time.FastDateFormat
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.internal.SQLConf
class CSVOptions(
@transient val parameters: CaseInsensitiveMap[String],
@@ -80,6 +81,8 @@ class CSVOptions(
}
}
+ private[csv] val columnPruning = SQLConf.get.getConf(SQLConf.CSV_PARSER_COLUMN_PRUNING)
+
val delimiter = CSVUtils.toChar(
parameters.getOrElse("sep", parameters.getOrElse("delimiter", ",")))
val parseMode: ParseMode =
@@ -164,7 +167,7 @@ class CSVOptions(
writerSettings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceFlagInWrite)
writerSettings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceFlagInWrite)
writerSettings.setNullValue(nullValue)
- writerSettings.setEmptyValue(nullValue)
+ writerSettings.setEmptyValue("\"\"")
writerSettings.setSkipEmptyLines(true)
writerSettings.setQuoteAllFields(quoteAll)
writerSettings.setQuoteEscapingEnabled(escapeQuotes)
@@ -185,6 +188,7 @@ class CSVOptions(
settings.setInputBufferSize(inputBufferSize)
settings.setMaxColumns(maxColumns)
settings.setNullValue(nullValue)
+ settings.setEmptyValue("")
settings.setMaxCharsPerColumn(maxCharsPerColumn)
settings.setUnescapedQuoteHandling(UnescapedQuoteHandling.STOP_AT_DELIMITER)
settings
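A hedged sketch of the round-trip behavior these two setEmptyValue calls aim for: empty strings are written as a quoted "" and read back as empty strings, while nulls keep going through the nullValue path. The output path is a placeholder and an active SparkSession named `spark` is assumed:

    import spark.implicits._

    val df = Seq(("a", ""), ("b", null)).toDF("k", "v")
    df.write.mode("overwrite").csv("/tmp/csv-empty-demo")   // "" written as "", null as an empty field
    spark.read.schema(df.schema).csv("/tmp/csv-empty-demo").show()
    // expected: ("a", "") keeps the empty string, ("b", null) stays null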
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
index 99557a1ceb0c8..4f00cc5eb3f39 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
@@ -34,10 +34,10 @@ import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
class UnivocityParser(
- schema: StructType,
+ dataSchema: StructType,
requiredSchema: StructType,
val options: CSVOptions) extends Logging {
- require(requiredSchema.toSet.subsetOf(schema.toSet),
+ require(requiredSchema.toSet.subsetOf(dataSchema.toSet),
"requiredSchema should be the subset of schema.")
def this(schema: StructType, options: CSVOptions) = this(schema, schema, options)
@@ -45,9 +45,17 @@ class UnivocityParser(
// A `ValueConverter` is responsible for converting the given value to a desired type.
private type ValueConverter = String => Any
- private val tokenizer = new CsvParser(options.asParserSettings)
+ private val tokenizer = {
+ val parserSetting = options.asParserSettings
+ if (options.columnPruning && requiredSchema.length < dataSchema.length) {
+ val tokenIndexArr = requiredSchema.map(f => java.lang.Integer.valueOf(dataSchema.indexOf(f)))
+ parserSetting.selectIndexes(tokenIndexArr: _*)
+ }
+ new CsvParser(parserSetting)
+ }
+ private val schema = if (options.columnPruning) requiredSchema else dataSchema
- private val row = new GenericInternalRow(requiredSchema.length)
+ private val row = new GenericInternalRow(schema.length)
// Retrieve the raw record string.
private def getCurrentInput: UTF8String = {
@@ -73,11 +81,8 @@ class UnivocityParser(
// Each input token is placed in each output row's position by mapping these. In this case,
//
// output row - ["A", 2]
- private val valueConverters: Array[ValueConverter] =
+ private val valueConverters: Array[ValueConverter] = {
schema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray
-
- private val tokenIndexArr: Array[Int] = {
- requiredSchema.map(f => schema.indexOf(f)).toArray
}
/**
@@ -210,9 +215,8 @@ class UnivocityParser(
} else {
try {
var i = 0
- while (i < requiredSchema.length) {
- val from = tokenIndexArr(i)
- row(i) = valueConverters(from).apply(tokens(from))
+ while (i < schema.length) {
+ row(i) = valueConverters(i).apply(tokens(i))
i += 1
}
row
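The pruning relies on univocity's selectIndexes, which makes the tokenizer emit only the selected columns, so rows can now be filled positionally without the old tokenIndexArr remapping. A Spark-free sketch of that univocity behavior (the input values are arbitrary):

    import com.univocity.parsers.csv.{CsvParser, CsvParserSettings}

    // Only columns 0 and 2 are tokenized; the returned array contains just the selected fields.
    val settings = new CsvParserSettings()
    settings.selectIndexes(Integer.valueOf(0), Integer.valueOf(2))
    val parser = new CsvParser(settings)
    val tokens = parser.parseLine("a,b,c")
    println(tokens.toSeq)   // expected: (a, c)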
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
index b4e5d169066d9..a73a97c06fe5a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
@@ -89,6 +89,10 @@ class JDBCOptions(
// the number of partitions
val numPartitions = parameters.get(JDBC_NUM_PARTITIONS).map(_.toInt)
+ // the number of seconds the driver will wait for a Statement object to execute.
+ // Zero means there is no limit.
+ val queryTimeout = parameters.getOrElse(JDBC_QUERY_TIMEOUT, "0").toInt
+
// ------------------------------------------------------------
// Optional parameters only for reading
// ------------------------------------------------------------
@@ -160,6 +164,7 @@ object JDBCOptions {
val JDBC_LOWER_BOUND = newOption("lowerBound")
val JDBC_UPPER_BOUND = newOption("upperBound")
val JDBC_NUM_PARTITIONS = newOption("numPartitions")
+ val JDBC_QUERY_TIMEOUT = newOption("queryTimeout")
val JDBC_BATCH_FETCH_SIZE = newOption("fetchsize")
val JDBC_TRUNCATE = newOption("truncate")
val JDBC_CREATE_TABLE_OPTIONS = newOption("createTableOptions")
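A hedged usage sketch of the new option from the reader side (URL and table are placeholders; whether the timeout actually cancels a long-running query depends on the JDBC driver's setQueryTimeout support):

    // Ask the JDBC source to apply a 10-second statement timeout; 0 means no limit.
    val jdbcDF = spark.read
      .format("jdbc")
      .option("url", "jdbc:postgresql://dbhost:5432/mydb")   // placeholder URL
      .option("dbtable", "public.events")                    // placeholder table
      .option("queryTimeout", "10")
      .load()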
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 05326210f3242..0bab3689e5d0e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -57,6 +57,7 @@ object JDBCRDD extends Logging {
try {
val statement = conn.prepareStatement(dialect.getSchemaQuery(table))
try {
+ statement.setQueryTimeout(options.queryTimeout)
val rs = statement.executeQuery()
try {
JdbcUtils.getSchema(rs, dialect, alwaysNullable = true)
@@ -281,6 +282,7 @@ private[jdbc] class JDBCRDD(
val statement = conn.prepareStatement(sql)
logInfo(s"Executing sessionInitStatement: $sql")
try {
+ statement.setQueryTimeout(options.queryTimeout)
statement.execute()
} finally {
statement.close()
@@ -298,6 +300,7 @@ private[jdbc] class JDBCRDD(
stmt = conn.prepareStatement(sqlText,
ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
stmt.setFetchSize(options.fetchSize)
+ stmt.setQueryTimeout(options.queryTimeout)
rs = stmt.executeQuery()
val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
index cc506e51bd0c6..f8c5677ea0f2a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
@@ -73,7 +73,7 @@ class JdbcRelationProvider extends CreatableRelationProvider
saveTable(df, tableSchema, isCaseSensitive, options)
} else {
// Otherwise, do not truncate the table, instead drop and recreate it
- dropTable(conn, options.table)
+ dropTable(conn, options.table, options)
createTable(conn, df, options)
saveTable(df, Some(df.schema), isCaseSensitive, options)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index e6dc2fda4eb1b..433443007cfd8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -76,6 +76,7 @@ object JdbcUtils extends Logging {
Try {
val statement = conn.prepareStatement(dialect.getTableExistsQuery(options.table))
try {
+ statement.setQueryTimeout(options.queryTimeout)
statement.executeQuery()
} finally {
statement.close()
@@ -86,9 +87,10 @@ object JdbcUtils extends Logging {
/**
* Drops a table from the JDBC database.
*/
- def dropTable(conn: Connection, table: String): Unit = {
+ def dropTable(conn: Connection, table: String, options: JDBCOptions): Unit = {
val statement = conn.createStatement
try {
+ statement.setQueryTimeout(options.queryTimeout)
statement.executeUpdate(s"DROP TABLE $table")
} finally {
statement.close()
@@ -102,6 +104,7 @@ object JdbcUtils extends Logging {
val dialect = JdbcDialects.get(options.url)
val statement = conn.createStatement
try {
+ statement.setQueryTimeout(options.queryTimeout)
statement.executeUpdate(dialect.getTruncateQuery(options.table))
} finally {
statement.close()
@@ -254,6 +257,7 @@ object JdbcUtils extends Logging {
try {
val statement = conn.prepareStatement(dialect.getSchemaQuery(options.table))
try {
+ statement.setQueryTimeout(options.queryTimeout)
Some(getSchema(statement.executeQuery(), dialect))
} catch {
case _: SQLException => None
@@ -596,7 +600,8 @@ object JdbcUtils extends Logging {
insertStmt: String,
batchSize: Int,
dialect: JdbcDialect,
- isolationLevel: Int): Iterator[Byte] = {
+ isolationLevel: Int,
+ options: JDBCOptions): Iterator[Byte] = {
val conn = getConnection()
var committed = false
@@ -637,6 +642,9 @@ object JdbcUtils extends Logging {
try {
var rowCount = 0
+
+ stmt.setQueryTimeout(options.queryTimeout)
+
while (iterator.hasNext) {
val row = iterator.next()
var i = 0
@@ -819,7 +827,8 @@ object JdbcUtils extends Logging {
case _ => df
}
repartitionedDF.rdd.foreachPartition(iterator => savePartition(
- getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel)
+ getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel,
+ options)
)
}
@@ -841,6 +850,7 @@ object JdbcUtils extends Logging {
val sql = s"CREATE TABLE $table ($strSchema) $createTableOptions"
val statement = conn.createStatement
try {
+ statement.setQueryTimeout(options.queryTimeout)
statement.executeUpdate(sql)
} finally {
statement.close()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
index ba83df0efebd0..3b6df45e949e8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
@@ -34,6 +34,7 @@ import org.apache.spark.rdd.{BinaryFileRDD, RDD}
import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
+import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.text.TextFileFormat
import org.apache.spark.sql.types.StructType
@@ -104,22 +105,19 @@ object TextInputJsonDataSource extends JsonDataSource {
CreateJacksonParser.internalRow(enc, _: JsonFactory, _: InternalRow)
}.getOrElse(CreateJacksonParser.internalRow(_: JsonFactory, _: InternalRow))
- JsonInferSchema.infer(rdd, parsedOptions, rowParser)
+ SQLExecution.withSQLConfPropagated(json.sparkSession) {
+ JsonInferSchema.infer(rdd, parsedOptions, rowParser)
+ }
}
private def createBaseDataset(
sparkSession: SparkSession,
inputPaths: Seq[FileStatus],
parsedOptions: JSONOptions): Dataset[String] = {
- val paths = inputPaths.map(_.getPath.toString)
- val textOptions = Map.empty[String, String] ++
- parsedOptions.encoding.map("encoding" -> _) ++
- parsedOptions.lineSeparator.map("lineSep" -> _)
-
sparkSession.baseRelationToDataFrame(
DataSource.apply(
sparkSession,
- paths = paths,
+ paths = inputPaths.map(_.getPath.toString),
className = classOf[TextFileFormat].getName,
options = parsedOptions.parameters
).resolveRelation(checkFilesExist = false))
@@ -165,7 +163,9 @@ object MultiLineJsonDataSource extends JsonDataSource {
.map(enc => createParser(enc, _: JsonFactory, _: PortableDataStream))
.getOrElse(createParser(_: JsonFactory, _: PortableDataStream))
- JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser)
+ SQLExecution.withSQLConfPropagated(sparkSession) {
+ JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser)
+ }
}
private def createBaseRdd(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
index e0424b7478122..e7eed95a560a3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
@@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.analysis.TypeCoercion
import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil
import org.apache.spark.sql.catalyst.json.JSONOptions
import org.apache.spark.sql.catalyst.util.{DropMalformedMode, FailFastMode, ParseMode, PermissiveMode}
-import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -45,17 +44,17 @@ private[sql] object JsonInferSchema {
createParser: (JsonFactory, T) => JsonParser): StructType = {
val parseMode = configOptions.parseMode
val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord
- val caseSensitive = SQLConf.get.caseSensitiveAnalysis
- // perform schema inference on each row and merge afterwards
- val rootType = json.mapPartitions { iter =>
+ // In each RDD partition, perform schema inference on each row and merge afterwards.
+ val typeMerger = compatibleRootType(columnNameOfCorruptRecord, parseMode)
+ val mergedTypesFromPartitions = json.mapPartitions { iter =>
val factory = new JsonFactory()
configOptions.setJacksonOptions(factory)
iter.flatMap { row =>
try {
Utils.tryWithResource(createParser(factory, row)) { parser =>
parser.nextToken()
- Some(inferField(parser, configOptions, caseSensitive))
+ Some(inferField(parser, configOptions))
}
} catch {
case e @ (_: RuntimeException | _: JsonProcessingException) => parseMode match {
@@ -68,9 +67,13 @@ private[sql] object JsonInferSchema {
s"Parse Mode: ${FailFastMode.name}.", e)
}
}
- }
- }.fold(StructType(Nil))(
- compatibleRootType(columnNameOfCorruptRecord, parseMode, caseSensitive))
+ }.reduceOption(typeMerger).toIterator
+ }
+
+ // Here we fold over the RDD's local iterator instead of calling `RDD.fold` directly, because
+ // `RDD.fold` runs the fold function in the DAGScheduler event loop thread, which may not have
+ // an active SparkSession, so `SQLConf.get` may point to the wrong configs.
+ val rootType = mergedTypesFromPartitions.toLocalIterator.fold(StructType(Nil))(typeMerger)
canonicalizeType(rootType) match {
case Some(st: StructType) => st
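A simplified, Spark-generic sketch of the same "merge per partition, then fold locally on the driver" pattern (assumes an active SparkSession named `spark`; the Int merge function stands in for the JSON type merger):

    val rdd = spark.sparkContext.parallelize(1 to 100, numSlices = 4)
    val merge: (Int, Int) => Int = _ + _

    // Merge inside each partition on the executors...
    val perPartition = rdd.mapPartitions(iter => iter.reduceOption(merge).toIterator)
    // ...then fold the per-partition results on the driver thread, where the right session is active.
    println(perPartition.toLocalIterator.fold(0)(merge))   // 5050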
@@ -100,15 +103,14 @@ private[sql] object JsonInferSchema {
/**
* Infer the type of a json document from the parser's token stream
*/
- private def inferField(
- parser: JsonParser, configOptions: JSONOptions, caseSensitive: Boolean): DataType = {
+ private def inferField(parser: JsonParser, configOptions: JSONOptions): DataType = {
import com.fasterxml.jackson.core.JsonToken._
parser.getCurrentToken match {
case null | VALUE_NULL => NullType
case FIELD_NAME =>
parser.nextToken()
- inferField(parser, configOptions, caseSensitive)
+ inferField(parser, configOptions)
case VALUE_STRING if parser.getTextLength < 1 =>
// Zero length strings and nulls have special handling to deal
@@ -125,7 +127,7 @@ private[sql] object JsonInferSchema {
while (nextUntil(parser, END_OBJECT)) {
builder += StructField(
parser.getCurrentName,
- inferField(parser, configOptions, caseSensitive),
+ inferField(parser, configOptions),
nullable = true)
}
val fields: Array[StructField] = builder.result()
@@ -140,7 +142,7 @@ private[sql] object JsonInferSchema {
var elementType: DataType = NullType
while (nextUntil(parser, END_ARRAY)) {
elementType = compatibleType(
- elementType, inferField(parser, configOptions, caseSensitive), caseSensitive)
+ elementType, inferField(parser, configOptions))
}
ArrayType(elementType)
@@ -246,14 +248,13 @@ private[sql] object JsonInferSchema {
*/
private def compatibleRootType(
columnNameOfCorruptRecords: String,
- parseMode: ParseMode,
- caseSensitive: Boolean): (DataType, DataType) => DataType = {
+ parseMode: ParseMode): (DataType, DataType) => DataType = {
// Since we support array of json objects at the top level,
// we need to check the element type and find the root level data type.
case (ArrayType(ty1, _), ty2) =>
- compatibleRootType(columnNameOfCorruptRecords, parseMode, caseSensitive)(ty1, ty2)
+ compatibleRootType(columnNameOfCorruptRecords, parseMode)(ty1, ty2)
case (ty1, ArrayType(ty2, _)) =>
- compatibleRootType(columnNameOfCorruptRecords, parseMode, caseSensitive)(ty1, ty2)
+ compatibleRootType(columnNameOfCorruptRecords, parseMode)(ty1, ty2)
// Discard null/empty documents
case (struct: StructType, NullType) => struct
case (NullType, struct: StructType) => struct
@@ -263,7 +264,7 @@ private[sql] object JsonInferSchema {
withCorruptField(struct, o, columnNameOfCorruptRecords, parseMode)
// If we get anything else, we call compatibleType.
// Usually, when we reach here, ty1 and ty2 are two StructTypes.
- case (ty1, ty2) => compatibleType(ty1, ty2, caseSensitive)
+ case (ty1, ty2) => compatibleType(ty1, ty2)
}
private[this] val emptyStructFieldArray = Array.empty[StructField]
@@ -271,8 +272,8 @@ private[sql] object JsonInferSchema {
/**
* Returns the most general data type for two given data types.
*/
- def compatibleType(t1: DataType, t2: DataType, caseSensitive: Boolean): DataType = {
- TypeCoercion.findTightestCommonType(t1, t2, caseSensitive).getOrElse {
+ def compatibleType(t1: DataType, t2: DataType): DataType = {
+ TypeCoercion.findTightestCommonType(t1, t2).getOrElse {
// t1 or t2 is a StructType, ArrayType, or an unexpected type.
(t1, t2) match {
// Double support larger range than fixed decimal, DecimalType.Maximum should be enough
@@ -307,8 +308,7 @@ private[sql] object JsonInferSchema {
val f2Name = fields2(f2Idx).name
val comp = f1Name.compareTo(f2Name)
if (comp == 0) {
- val dataType = compatibleType(
- fields1(f1Idx).dataType, fields2(f2Idx).dataType, caseSensitive)
+ val dataType = compatibleType(fields1(f1Idx).dataType, fields2(f2Idx).dataType)
newFields.add(StructField(f1Name, dataType, nullable = true))
f1Idx += 1
f2Idx += 1
@@ -331,17 +331,15 @@ private[sql] object JsonInferSchema {
StructType(newFields.toArray(emptyStructFieldArray))
case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) =>
- ArrayType(
- compatibleType(elementType1, elementType2, caseSensitive),
- containsNull1 || containsNull2)
+ ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2)
// The case that given `DecimalType` is capable of given `IntegralType` is handled in
// `findTightestCommonTypeOfTwo`. Both cases below will be executed only when
// the given `DecimalType` is not capable of the given `IntegralType`.
case (t1: IntegralType, t2: DecimalType) =>
- compatibleType(DecimalType.forType(t1), t2, caseSensitive)
+ compatibleType(DecimalType.forType(t1), t2)
case (t1: DecimalType, t2: IntegralType) =>
- compatibleType(t1, DecimalType.forType(t2), caseSensitive)
+ compatibleType(t1, DecimalType.forType(t2))
// strings and every string is a Json object.
case (_, _) => StringType
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
index 0dea767840ed3..cab00251622b8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
@@ -61,7 +61,7 @@ class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] {
case _: ClassNotFoundException => u
case e: Exception =>
// the provider is valid, but failed to create a logical plan
- u.failAnalysis(e.getMessage)
+ u.failAnalysis(e.getMessage, e)
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
index 1a6b32429313a..8d6fb3820d420 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
@@ -29,12 +29,12 @@ class DataSourceRDDPartition[T : ClassTag](val index: Int, val inputPartition: I
class DataSourceRDD[T: ClassTag](
sc: SparkContext,
- @transient private val readerFactories: Seq[InputPartition[T]])
+ @transient private val inputPartitions: Seq[InputPartition[T]])
extends RDD[T](sc, Nil) {
override protected def getPartitions: Array[Partition] = {
- readerFactories.zipWithIndex.map {
- case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory)
+ inputPartitions.zipWithIndex.map {
+ case (inputPartition, index) => new DataSourceRDDPartition(index, inputPartition)
}.toArray
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala
index 9293d4f831bff..e894f8afd6762 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala
@@ -23,17 +23,10 @@ import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project
import org.apache.spark.sql.catalyst.rules.Rule
object PushDownOperatorsToDataSource extends Rule[LogicalPlan] {
- override def apply(
- plan: LogicalPlan): LogicalPlan = plan transformUp {
+ override def apply(plan: LogicalPlan): LogicalPlan = plan match {
// PhysicalOperation guarantees that filters are deterministic; no need to check
- case PhysicalOperation(project, newFilters, relation : DataSourceV2Relation) =>
- // merge the filters
- val filters = relation.filters match {
- case Some(existing) =>
- existing ++ newFilters
- case _ =>
- newFilters
- }
+ case PhysicalOperation(project, filters, relation: DataSourceV2Relation) =>
+ assert(relation.filters.isEmpty, "data source v2 should do push down only once.")
val projectAttrs = project.map(_.toAttribute)
val projectSet = AttributeSet(project.flatMap(_.references))
@@ -67,5 +60,7 @@ object PushDownOperatorsToDataSource extends Rule[LogicalPlan] {
} else {
filtered
}
+
+ case other => other.mapChildren(apply)
}
}
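For readers unfamiliar with the switch from transformUp to explicit recursion, a Catalyst-free toy sketch of the new shape (the Node/Leaf/Branch types are hypothetical, not Spark classes): the rule handles the interesting node exactly once and otherwise recurses via mapChildren.

    sealed trait Node { def mapChildren(f: Node => Node): Node }
    case class Leaf(tag: String) extends Node { def mapChildren(f: Node => Node): Node = this }
    case class Branch(children: Seq[Node]) extends Node {
      def mapChildren(f: Node => Node): Node = Branch(children.map(f))
    }

    def rewrite(plan: Node): Node = plan match {
      case Leaf("v2") => Leaf("v2-pushed")          // push down exactly once at this node
      case other      => other.mapChildren(rewrite) // otherwise recurse into children
    }

    println(rewrite(Branch(Seq(Leaf("v2"), Branch(Seq(Leaf("x")))))))
    // Branch(List(Leaf(v2-pushed), Branch(List(Leaf(x)))))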
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
index daea6c39624d6..9e0ec9481b0de 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
@@ -69,7 +69,7 @@ case class BroadcastExchangeExec(
Future {
// This will run in another thread. Set the execution id so that we can connect these jobs
// with the correct execution.
- SQLExecution.withExecutionId(sparkContext, executionId) {
+ SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) {
try {
val beforeCollect = System.nanoTime()
// Use executeCollect/executeCollectIterator to avoid conversion to Scala types
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index 6fa716d9fadee..0da0e8610c392 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, SparkPlan}
@@ -183,7 +184,7 @@ case class BroadcastHashJoinExec(
val isNull = ctx.freshName("isNull")
val value = ctx.freshName("value")
val javaType = CodeGenerator.javaType(a.dataType)
- val code = s"""
+ val code = code"""
|boolean $isNull = true;
|$javaType $value = ${CodeGenerator.defaultValue(a.dataType)};
|if ($matched != null) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index d8261f0f33b61..f4b9d132122e4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution._
@@ -521,7 +522,7 @@ case class SortMergeJoinExec(
if (a.nullable) {
val isNull = ctx.freshName("isNull")
val code =
- s"""
+ code"""
|$isNull = $leftRow.isNullAt($i);
|$value = $isNull ? $defaultValue : ($valueCode);
""".stripMargin
@@ -533,7 +534,7 @@ case class SortMergeJoinExec(
(ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, a.dataType)),
leftVarsDecl)
} else {
- val code = s"$value = $valueCode;"
+ val code = code"$value = $valueCode;"
val leftVarsDecl = s"""$javaType $value = $defaultValue;"""
(ExprCode(code, FalseLiteral, JavaCode.variable(value, a.dataType)), leftVarsDecl)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
index 80769d728b8f1..8e82cccbc8fa3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
@@ -97,6 +97,18 @@ case class FlatMapGroupsWithStateExec(
override def keyExpressions: Seq[Attribute] = groupingAttributes
+ override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = {
+ timeoutConf match {
+ case ProcessingTimeTimeout =>
+ true // Always run batches to process timeouts
+ case EventTimeTimeout =>
// Process another non-data batch only if the watermark has advanced since this plan was executed
+ eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get
+ case _ =>
+ false
+ }
+ }
+
override protected def doExecute(): RDD[InternalRow] = {
metrics // force lazy init at driver
@@ -126,7 +138,6 @@ case class FlatMapGroupsWithStateExec(
case _ =>
iter
}
-
// Generate an iterator that returns the rows grouped by the grouping function
// Note that this code ensures that the filtering for timeout occurs only after
// all the data has been processed. This is to ensure that the timeout information of all
@@ -194,11 +205,11 @@ case class FlatMapGroupsWithStateExec(
throw new IllegalStateException(
s"Cannot filter timed out keys for $timeoutConf")
}
- val timingOutKeys = store.getRange(None, None).filter { rowPair =>
+ val timingOutPairs = store.getRange(None, None).filter { rowPair =>
val timeoutTimestamp = getTimeoutTimestamp(rowPair.value)
timeoutTimestamp != NO_TIMESTAMP && timeoutTimestamp < timeoutThreshold
}
- timingOutKeys.flatMap { rowPair =>
+ timingOutPairs.flatMap { rowPair =>
callFunctionAndUpdateState(rowPair.key, Iterator.empty, rowPair.value, hasTimedOut = true)
}
} else Iterator.empty
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java
index 80aa5505db991..43ad4b3384ec3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java
@@ -19,8 +19,8 @@
/**
* This is an internal, deprecated interface. New source implementations should use the
- * org.apache.spark.sql.sources.v2.reader.Offset class, which is the one that will be supported
- * in the long term.
+ * org.apache.spark.sql.sources.v2.reader.streaming.Offset class, which is the one that will be
+ * supported in the long term.
*
* This class will be removed in a future release.
*/
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
index fa7c8ee906ecd..afa664eb76525 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
@@ -187,6 +187,17 @@ case class StreamingSymmetricHashJoinExec(
s"${getClass.getSimpleName} should not take $x as the JoinType")
}
+ override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = {
+ val watermarkUsedForStateCleanup =
+ stateWatermarkPredicates.left.nonEmpty || stateWatermarkPredicates.right.nonEmpty
+
+ // The latest watermark value is greater than the one used in the previously executed plan
+ val watermarkHasChanged =
+ eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get
+
+ watermarkUsedForStateCleanup && watermarkHasChanged
+ }
+
protected override def doExecute(): RDD[InternalRow] = {
val stateStoreCoord = sqlContext.sessionState.streamingQueryManager.stateStoreCoordinator
val stateStoreNames = SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide)
@@ -319,8 +330,7 @@ case class StreamingSymmetricHashJoinExec(
// outer join) if possible. In all cases, nothing needs to be outputted, hence the removal
// needs to be done greedily by immediately consuming the returned iterator.
val cleanupIter = joinType match {
- case Inner =>
- leftSideJoiner.removeOldState() ++ rightSideJoiner.removeOldState()
+ case Inner => leftSideJoiner.removeOldState() ++ rightSideJoiner.removeOldState()
case LeftOuter => rightSideJoiner.removeOldState()
case RightOuter => leftSideJoiner.removeOldState()
case _ => throwBadJoinTypeException()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
index f58146ac42398..0e7d1019b9c8f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
@@ -122,16 +122,7 @@ class ContinuousExecution(
s"Batch $latestEpochId was committed without end epoch offsets!")
}
committedOffsets = nextOffsets.toStreamProgress(sources)
-
- // Get to an epoch ID that has definitely never been sent to a sink before. Since sink
- // commit happens between offset log write and commit log write, this means an epoch ID
- // which is not in the offset log.
- val (latestOffsetEpoch, _) = offsetLog.getLatest().getOrElse {
- throw new IllegalStateException(
- s"Offset log had no latest element. This shouldn't be possible because nextOffsets is" +
- s"an element.")
- }
- currentBatchId = latestOffsetEpoch + 1
+ currentBatchId = latestEpochId + 1
logDebug(s"Resuming at epoch $currentBatchId with committed offsets $committedOffsets")
nextOffsets
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala
index d8645576c2052..f38577b6a9f16 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala
@@ -46,8 +46,6 @@ class ContinuousQueuedDataReader(
// Important sequencing - we must get our starting point before the provider threads start running
private var currentOffset: PartitionOffset =
ContinuousDataSourceRDD.getContinuousReader(reader).getOffset
- private var currentEpoch: Long =
- context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong
/**
* The record types in the read buffer.
@@ -115,8 +113,7 @@ class ContinuousQueuedDataReader(
currentEntry match {
case EpochMarker =>
epochCoordEndpoint.send(ReportPartitionOffset(
- context.partitionId(), currentEpoch, currentOffset))
- currentEpoch += 1
+ context.partitionId(), EpochTracker.getCurrentEpoch.get, currentOffset))
null
case ContinuousRow(row, offset) =>
currentOffset = offset
@@ -184,7 +181,7 @@ class ContinuousQueuedDataReader(
private val epochCoordEndpoint = EpochCoordinatorRef.get(
context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get)
- // Note that this is *not* the same as the currentEpoch in [[ContinuousDataQueuedReader]]! That
+ // Note that this is *not* the same as the currentEpoch in [[ContinuousWriteRDD]]! That
// field represents the epoch wrt the data being processed. The currentEpoch here is just a
// counter to ensure we send the appropriate number of markers if we fall behind the driver.
private var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala
index 8d25d9ccc43d3..516a563bdcc7a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala
@@ -85,7 +85,7 @@ class RateStreamContinuousReader(options: DataSourceOptions)
val start = partitionStartMap(i)
// Have each partition advance by numPartitions each row, with starting points staggered
// by their partition index.
- RateStreamContinuousDataReaderFactory(
+ RateStreamContinuousInputPartition(
start.value,
start.runTimeMs,
i,
@@ -113,7 +113,7 @@ class RateStreamContinuousReader(options: DataSourceOptions)
}
-case class RateStreamContinuousDataReaderFactory(
+case class RateStreamContinuousInputPartition(
startValue: Long,
startTimeMs: Long,
partitionIndex: Int,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala
index 91f1576581511..ef5f0da1e7cc2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala
@@ -45,7 +45,8 @@ class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactor
val epochCoordinator = EpochCoordinatorRef.get(
context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),
SparkEnv.get)
- var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong
+ EpochTracker.initializeCurrentEpoch(
+ context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong)
while (!context.isInterrupted() && !context.isCompleted()) {
var dataWriter: DataWriter[InternalRow] = null
@@ -54,19 +55,24 @@ class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactor
try {
val dataIterator = prev.compute(split, context)
dataWriter = writeTask.createDataWriter(
- context.partitionId(), context.attemptNumber(), currentEpoch)
+ context.partitionId(),
+ context.attemptNumber(),
+ EpochTracker.getCurrentEpoch.get)
while (dataIterator.hasNext) {
dataWriter.write(dataIterator.next())
}
logInfo(s"Writer for partition ${context.partitionId()} " +
- s"in epoch $currentEpoch is committing.")
+ s"in epoch ${EpochTracker.getCurrentEpoch.get} is committing.")
val msg = dataWriter.commit()
epochCoordinator.send(
- CommitPartitionEpoch(context.partitionId(), currentEpoch, msg)
+ CommitPartitionEpoch(
+ context.partitionId(),
+ EpochTracker.getCurrentEpoch.get,
+ msg)
)
logInfo(s"Writer for partition ${context.partitionId()} " +
- s"in epoch $currentEpoch committed.")
- currentEpoch += 1
+ s"in epoch ${EpochTracker.getCurrentEpoch.get} committed.")
+ EpochTracker.incrementCurrentEpoch()
} catch {
case _: InterruptedException =>
// Continuous shutdown always involves an interrupt. Just finish the task.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala
index cc6808065c0cd..8877ebeb26735 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala
@@ -137,30 +137,71 @@ private[continuous] class EpochCoordinator(
private val partitionOffsets =
mutable.Map[(Long, Int), PartitionOffset]()
+ private var lastCommittedEpoch = startEpoch - 1
+ // Remembers epochs that have to wait for previous epochs to be committed first.
+ private val epochsWaitingToBeCommitted = mutable.HashSet.empty[Long]
+
private def resolveCommitsAtEpoch(epoch: Long) = {
- val thisEpochCommits =
- partitionCommits.collect { case ((e, _), msg) if e == epoch => msg }
+ val thisEpochCommits = findPartitionCommitsForEpoch(epoch)
val nextEpochOffsets =
partitionOffsets.collect { case ((e, _), o) if e == epoch => o }
if (thisEpochCommits.size == numWriterPartitions &&
nextEpochOffsets.size == numReaderPartitions) {
- logDebug(s"Epoch $epoch has received commits from all partitions. Committing globally.")
- // Sequencing is important here. We must commit to the writer before recording the commit
- // in the query, or we will end up dropping the commit if we restart in the middle.
- writer.commit(epoch, thisEpochCommits.toArray)
- query.commit(epoch)
-
- // Cleanup state from before this epoch, now that we know all partitions are forever past it.
- for (k <- partitionCommits.keys.filter { case (e, _) => e < epoch }) {
- partitionCommits.remove(k)
- }
- for (k <- partitionOffsets.keys.filter { case (e, _) => e < epoch }) {
- partitionOffsets.remove(k)
+
+ // To keep epoch commits in order, check that the last committed epoch is the immediately
+ // preceding one. If it is not, park the epoch currently being processed among the epochs
+ // waiting to be committed; otherwise commit it now.
+ if (lastCommittedEpoch != epoch - 1) {
+ logDebug(s"Epoch $epoch has received commits from all partitions " +
+ s"and is waiting for epoch ${epoch - 1} to be committed first.")
+ epochsWaitingToBeCommitted.add(epoch)
+ } else {
+ commitEpoch(epoch, thisEpochCommits)
+ lastCommittedEpoch = epoch
+
+ // Commit subsequent epochs that are waiting to be committed.
+ var nextEpoch = lastCommittedEpoch + 1
+ while (epochsWaitingToBeCommitted.contains(nextEpoch)) {
+ val nextEpochCommits = findPartitionCommitsForEpoch(nextEpoch)
+ commitEpoch(nextEpoch, nextEpochCommits)
+
+ epochsWaitingToBeCommitted.remove(nextEpoch)
+ lastCommittedEpoch = nextEpoch
+ nextEpoch += 1
+ }
+
+ // Cleanup state from before last committed epoch,
+ // now that we know all partitions are forever past it.
+ for (k <- partitionCommits.keys.filter { case (e, _) => e < lastCommittedEpoch }) {
+ partitionCommits.remove(k)
+ }
+ for (k <- partitionOffsets.keys.filter { case (e, _) => e < lastCommittedEpoch }) {
+ partitionOffsets.remove(k)
+ }
}
}
}
+ /**
+ * Collect per-partition commits for an epoch.
+ */
+ private def findPartitionCommitsForEpoch(epoch: Long): Iterable[WriterCommitMessage] = {
+ partitionCommits.collect { case ((e, _), msg) if e == epoch => msg }
+ }
+
+ /**
+ * Commit the epoch to the writer and record the commit in the query.
+ */
+ private def commitEpoch(epoch: Long, messages: Iterable[WriterCommitMessage]): Unit = {
+ logDebug(s"Epoch $epoch has received commits from all partitions " +
+ s"and is ready to be committed. Committing epoch $epoch.")
+ // Sequencing is important here. We must commit to the writer before recording the commit
+ // in the query, or we will end up dropping the commit if we restart in the middle.
+ writer.commit(epoch, messages.toArray)
+ query.commit(epoch)
+ }
+
override def receive: PartialFunction[Any, Unit] = {
// If we just drop these messages, we won't do any writes to the query. The lame duck tasks
// won't shed errors or anything.
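A self-contained sketch of the in-order commit bookkeeping introduced above (names are hypothetical; the real coordinator also gathers per-partition commit messages and offsets):

    import scala.collection.mutable

    var lastCommitted = 0L                      // analogous to startEpoch - 1
    val waiting = mutable.HashSet.empty[Long]

    def onEpochReady(epoch: Long)(commit: Long => Unit): Unit = {
      if (epoch != lastCommitted + 1) {
        waiting.add(epoch)                      // arrived out of order; park it
      } else {
        commit(epoch); lastCommitted = epoch
        var next = lastCommitted + 1
        while (waiting.remove(next)) {          // drain any epochs that were waiting on this one
          commit(next); lastCommitted = next; next += 1
        }
      }
    }

    onEpochReady(2)(e => println(s"commit $e"))   // parked: epoch 1 not committed yet
    onEpochReady(1)(e => println(s"commit $e"))   // commits 1, then drains and commits 2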
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochTracker.scala
new file mode 100644
index 0000000000000..bc0ae428d4521
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochTracker.scala
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous
+
+import java.util.concurrent.atomic.AtomicLong
+
+/**
+ * Tracks the current continuous processing epoch within a task. Call
+ * EpochTracker.getCurrentEpoch to get the current epoch.
+ */
+object EpochTracker {
+ // The current epoch. Note that this is a shared reference; ContinuousWriteRDD.compute() will
+ // update the underlying AtomicLong as it finishes epochs. Other code should only read the value.
+ private val currentEpoch: ThreadLocal[AtomicLong] = new ThreadLocal[AtomicLong] {
+ override def initialValue() = new AtomicLong(-1)
+ }
+
+ /**
+ * Get the current epoch for the current task, or None if the task has no current epoch.
+ */
+ def getCurrentEpoch: Option[Long] = {
+ currentEpoch.get().get() match {
+ case n if n < 0 => None
+ case e => Some(e)
+ }
+ }
+
+ /**
+ * Increment the current epoch for this task thread. Should be called by [[ContinuousWriteRDD]]
+ * between epochs.
+ */
+ def incrementCurrentEpoch(): Unit = {
+ currentEpoch.get().incrementAndGet()
+ }
+
+ /**
+ * Initialize the current epoch for this task thread. Should be called by [[ContinuousWriteRDD]]
+ * at the beginning of a task.
+ */
+ def initializeCurrentEpoch(startEpoch: Long): Unit = {
+ currentEpoch.get().set(startEpoch)
+ }
+}
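A hedged usage sketch of the new tracker, mirroring the calls ContinuousWriteRDD makes on a task thread (the values are arbitrary; the state is thread-local, so this only holds within one thread):

    EpochTracker.initializeCurrentEpoch(5L)
    assert(EpochTracker.getCurrentEpoch == Some(5L))
    EpochTracker.incrementCurrentEpoch()
    assert(EpochTracker.getCurrentEpoch == Some(6L))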
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala
new file mode 100644
index 0000000000000..270b1a5c28dee
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous.shuffle
+
+import java.util.UUID
+
+import org.apache.spark.{Partition, SparkContext, SparkEnv, TaskContext}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.util.NextIterator
+
+case class ContinuousShuffleReadPartition(index: Int, queueSize: Int) extends Partition {
+ // Initialized only on the executor, and only once even as we call compute() multiple times.
+ lazy val (reader: ContinuousShuffleReader, endpoint) = {
+ val env = SparkEnv.get.rpcEnv
+ val receiver = new UnsafeRowReceiver(queueSize, env)
+ val endpoint = env.setupEndpoint(s"UnsafeRowReceiver-${UUID.randomUUID()}", receiver)
+ TaskContext.get().addTaskCompletionListener { ctx =>
+ env.stop(endpoint)
+ }
+ (receiver, endpoint)
+ }
+}
+
+/**
+ * RDD at the reader side of each continuous processing shuffle task. Upstream tasks send their
+ * shuffle output to the wrapped receivers in this RDD's partitions; each of the RDD's tasks
+ * polls its receiver until an epoch marker is received.
+ */
+class ContinuousShuffleReadRDD(
+ sc: SparkContext,
+ numPartitions: Int,
+ queueSize: Int = 1024)
+ extends RDD[UnsafeRow](sc, Nil) {
+
+ override protected def getPartitions: Array[Partition] = {
+ (0 until numPartitions).map { partIndex =>
+ ContinuousShuffleReadPartition(partIndex, queueSize)
+ }.toArray
+ }
+
+ override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = {
+ split.asInstanceOf[ContinuousShuffleReadPartition].reader.read()
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala
new file mode 100644
index 0000000000000..42631c90ebc55
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous.shuffle
+
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+
+/**
+ * Trait for reading from a continuous processing shuffle.
+ */
+trait ContinuousShuffleReader {
+ /**
+ * Returns an iterator over the incoming rows in an epoch. Implementations should block waiting
+ * for new rows to arrive, and end the iterator once they've received epoch markers from all
+ * shuffle writers.
+ */
+ def read(): Iterator[UnsafeRow]
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala
new file mode 100644
index 0000000000000..b8adbb743c6c2
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous.shuffle
+
+import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue}
+import java.util.concurrent.atomic.AtomicBoolean
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.util.NextIterator
+
+/**
+ * Messages for the UnsafeRowReceiver endpoint. Either an incoming row or an epoch marker.
+ */
+private[shuffle] sealed trait UnsafeRowReceiverMessage extends Serializable
+private[shuffle] case class ReceiverRow(row: UnsafeRow) extends UnsafeRowReceiverMessage
+private[shuffle] case class ReceiverEpochMarker() extends UnsafeRowReceiverMessage
+
+/**
+ * RPC endpoint for receiving rows into a continuous processing shuffle task. Continuous shuffle
+ * writers will send rows here, with continuous shuffle readers polling for new rows as needed.
+ *
+ * TODO: Support multiple source tasks. We need to output a single epoch marker once all
+ * source tasks have sent one.
+ */
+private[shuffle] class UnsafeRowReceiver(
+ queueSize: Int,
+ override val rpcEnv: RpcEnv)
+ extends ThreadSafeRpcEndpoint with ContinuousShuffleReader with Logging {
+ // Note that this queue will be drained from the main task thread and populated in the RPC
+ // response thread.
+ private val queue = new ArrayBlockingQueue[UnsafeRowReceiverMessage](queueSize)
+
+ // Exposed for testing to determine if the endpoint gets stopped on task end.
+ private[shuffle] val stopped = new AtomicBoolean(false)
+
+ override def onStop(): Unit = {
+ stopped.set(true)
+ }
+
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+ case r: UnsafeRowReceiverMessage =>
+ queue.put(r)
+ context.reply(())
+ }
+
+ override def read(): Iterator[UnsafeRow] = {
+ new NextIterator[UnsafeRow] {
+ override def getNext(): UnsafeRow = queue.take() match {
+ case ReceiverRow(r) => r
+ case ReceiverEpochMarker() =>
+ finished = true
+ null
+ }
+
+ override def close(): Unit = {}
+ }
+ }
+}
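The read() implementation above is essentially "drain a blocking queue until the epoch marker". A Spark-free sketch of that pattern (Msg/DataMsg/Marker are hypothetical stand-ins for the receiver messages):

    import java.util.concurrent.ArrayBlockingQueue

    sealed trait Msg
    final case class DataMsg(value: Int) extends Msg
    case object Marker extends Msg

    val queue = new ArrayBlockingQueue[Msg](16)
    queue.put(DataMsg(1)); queue.put(DataMsg(2)); queue.put(Marker)

    val rows = Iterator.continually(queue.take())
      .takeWhile(_ != Marker)                    // the epoch marker ends the iterator
      .collect { case DataMsg(v) => v }
    println(rows.toList)                          // List(1, 2)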
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index daa2963220aef..b137f98045c5a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -156,7 +156,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
logDebug(generateDebugString(newBlocks.flatten, startOrdinal, endOrdinal))
newBlocks.map { block =>
- new MemoryStreamDataReaderFactory(block).asInstanceOf[InputPartition[UnsafeRow]]
+ new MemoryStreamInputPartition(block).asInstanceOf[InputPartition[UnsafeRow]]
}.asJava
}
}
@@ -201,7 +201,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
}
-class MemoryStreamDataReaderFactory(records: Array[UnsafeRow])
+class MemoryStreamInputPartition(records: Array[UnsafeRow])
extends InputPartition[UnsafeRow] {
override def createPartitionReader(): InputPartitionReader[UnsafeRow] = {
new InputPartitionReader[UnsafeRow] {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
index fef792eab69d5..d1c3498450096 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
@@ -44,13 +44,12 @@ import org.apache.spark.util.RpcUtils
* * ContinuousMemoryStream maintains a list of records for each partition. addData() will
* distribute records evenly-ish across partitions.
* * RecordEndpoint is set up as an endpoint for executor-side
- * ContinuousMemoryStreamDataReader instances to poll. It returns the record at the specified
- * offset within the list, or null if that offset doesn't yet have a record.
+ * ContinuousMemoryStreamInputPartitionReader instances to poll. It returns the record at
+ * the specified offset within the list, or null if that offset doesn't yet have a record.
*/
-class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
+class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPartitions: Int = 2)
extends MemoryStreamBase[A](sqlContext) with ContinuousReader with ContinuousReadSupport {
private implicit val formats = Serialization.formats(NoTypeHints)
- private val NUM_PARTITIONS = 2
protected val logicalPlan =
StreamingRelationV2(this, "memory", Map(), attributes, None)(sqlContext.sparkSession)
@@ -58,7 +57,7 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
// ContinuousReader implementation
@GuardedBy("this")
- private val records = Seq.fill(NUM_PARTITIONS)(new ListBuffer[A])
+ private val records = Seq.fill(numPartitions)(new ListBuffer[A])
@GuardedBy("this")
private var startOffset: ContinuousMemoryStreamOffset = _
@@ -69,17 +68,17 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
def addData(data: TraversableOnce[A]): Offset = synchronized {
// Distribute data evenly among partition lists.
data.toSeq.zipWithIndex.map {
- case (item, index) => records(index % NUM_PARTITIONS) += item
+ case (item, index) => records(index % numPartitions) += item
}
// The new target offset is the offset where all records in all partitions have been processed.
- ContinuousMemoryStreamOffset((0 until NUM_PARTITIONS).map(i => (i, records(i).size)).toMap)
+ ContinuousMemoryStreamOffset((0 until numPartitions).map(i => (i, records(i).size)).toMap)
}
override def setStartOffset(start: Optional[Offset]): Unit = synchronized {
// Inferred initial offset is position 0 in each partition.
startOffset = start.orElse {
- ContinuousMemoryStreamOffset((0 until NUM_PARTITIONS).map(i => (i, 0)).toMap)
+ ContinuousMemoryStreamOffset((0 until numPartitions).map(i => (i, 0)).toMap)
}.asInstanceOf[ContinuousMemoryStreamOffset]
}
@@ -107,7 +106,7 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
startOffset.partitionNums.map {
case (part, index) =>
- new ContinuousMemoryStreamDataReaderFactory(
+ new ContinuousMemoryStreamInputPartition(
endpointName, part, index): InputPartition[Row]
}.toList.asJava
}
@@ -152,12 +151,15 @@ object ContinuousMemoryStream {
def apply[A : Encoder](implicit sqlContext: SQLContext): ContinuousMemoryStream[A] =
new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext)
+
+ def singlePartition[A : Encoder](implicit sqlContext: SQLContext): ContinuousMemoryStream[A] =
+ new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext, 1)
}
/**
- * Data reader factory for continuous memory stream.
+ * An input partition for continuous memory stream.
*/
-class ContinuousMemoryStreamDataReaderFactory(
+class ContinuousMemoryStreamInputPartition(
driverEndpointName: String,
partition: Int,
startOffset: Int) extends InputPartition[Row] {
@@ -166,7 +168,7 @@ class ContinuousMemoryStreamDataReaderFactory(
}
/**
- * Data reader for continuous memory stream.
+ * An input partition reader for continuous memory stream.
*
* Polls the driver endpoint for new records.
*/
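As a rough illustration of the round-robin distribution and offset bookkeeping described above, here is a sketch with a plain class standing in for ContinuousMemoryStream: addData spreads items across numPartitions buffers and reports, per partition, how many records have been buffered so far (the shape of ContinuousMemoryStreamOffset).

```scala
import scala.collection.mutable.ListBuffer

// Toy stand-in for the partition-buffer bookkeeping in ContinuousMemoryStream.
class PartitionedBuffers[A](numPartitions: Int) {
  private val records = Seq.fill(numPartitions)(new ListBuffer[A])

  def addData(data: Seq[A]): Map[Int, Int] = synchronized {
    // Distribute items round-robin across the partition lists.
    data.zipWithIndex.foreach {
      case (item, index) => records(index % numPartitions) += item
    }
    // The "offset": the number of records buffered in each partition so far.
    (0 until numPartitions).map(i => (i, records(i).size)).toMap
  }
}

object PartitionedBuffersDemo extends App {
  val buffers = new PartitionedBuffers[String](numPartitions = 2)
  // "a" and "c" land in partition 0, "b" in partition 1.
  println(buffers.addData(Seq("a", "b", "c"))) // Map(0 -> 2, 1 -> 1)
}
```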
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala
index 723cc3ad5bb89..fbff8db987110 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala
@@ -167,7 +167,7 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation:
}
(0 until numPartitions).map { p =>
- new RateStreamMicroBatchDataReaderFactory(
+ new RateStreamMicroBatchInputPartition(
p, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue)
: InputPartition[Row]
}.toList.asJava
@@ -182,7 +182,7 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation:
s"numPartitions=${options.get(NUM_PARTITIONS).orElse("default")}"
}
-class RateStreamMicroBatchDataReaderFactory(
+class RateStreamMicroBatchInputPartition(
partitionId: Int,
numPartitions: Int,
rangeStart: Long,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
index 01d8e75980993..3f11b8f79943c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
@@ -23,6 +23,7 @@ import scala.reflect.ClassTag
import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.execution.streaming.continuous.EpochTracker
import org.apache.spark.sql.internal.SessionState
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.SerializableConfiguration
@@ -71,8 +72,15 @@ class StateStoreRDD[T: ClassTag, U: ClassTag](
StateStoreId(checkpointLocation, operatorId, partition.index),
queryRunId)
+ // If we're in continuous processing mode, we should get the store version for the current
+ // epoch rather than the one at planning time.
+ val currentVersion = EpochTracker.getCurrentEpoch match {
+ case None => storeVersion
+ case Some(value) => value
+ }
+
store = StateStore.get(
- storeProviderId, keySchema, valueSchema, indexOrdinal, storeVersion,
+ storeProviderId, keySchema, valueSchema, indexOrdinal, currentVersion,
storeConf, hadoopConfBroadcast.value.value)
val inputIter = dataRDD.iterator(partition, ctxt)
storeUpdateFunction(store, inputIter)
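The version-selection logic above is a simple Option fallback: when an epoch is attached to the running task (continuous processing), it overrides the store version chosen at planning time; otherwise the planned version is used. A minimal sketch of that pattern follows, assuming an EpochTracker-like helper backed by a thread-local; the real EpochTracker is not part of this hunk, so LocalEpochTracker here is a hypothetical stand-in.

```scala
// Hypothetical stand-in for an EpochTracker-style helper: the current epoch,
// if any, is carried in a thread-local on the task thread.
object LocalEpochTracker {
  private val currentEpoch = new ThreadLocal[Option[Long]] {
    override def initialValue(): Option[Long] = None
  }
  def setCurrentEpoch(epoch: Long): Unit = currentEpoch.set(Some(epoch))
  def clear(): Unit = currentEpoch.set(None)
  def getCurrentEpoch: Option[Long] = currentEpoch.get()
}

object VersionSelectionDemo extends App {
  val plannedStoreVersion = 7L

  // Same shape as the match in StateStoreRDD: epoch wins when present.
  def resolveVersion(): Long = LocalEpochTracker.getCurrentEpoch match {
    case None => plannedStoreVersion
    case Some(epoch) => epoch
  }

  println(resolveVersion()) // 7: no epoch set, fall back to the planned version
  LocalEpochTracker.setCurrentEpoch(42L)
  println(resolveVersion()) // 42: continuous mode, the epoch drives the version
}
```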
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala
index 582528777f90e..bf46bc4cf904d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala
@@ -58,21 +58,21 @@ private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with L
_content ++=
new RunningExecutionTable(
parent, s"Running Queries (${running.size})", currentTime,
- running.sortBy(_.submissionTime).reverse).toNodeSeq
+ running.sortBy(_.submissionTime).reverse).toNodeSeq(request)
}
if (completed.nonEmpty) {
_content ++=
new CompletedExecutionTable(
parent, s"Completed Queries (${completed.size})", currentTime,
- completed.sortBy(_.submissionTime).reverse).toNodeSeq
+ completed.sortBy(_.submissionTime).reverse).toNodeSeq(request)
}
if (failed.nonEmpty) {
_content ++=
new FailedExecutionTable(
parent, s"Failed Queries (${failed.size})", currentTime,
- failed.sortBy(_.submissionTime).reverse).toNodeSeq
+ failed.sortBy(_.submissionTime).reverse).toNodeSeq(request)
}
_content
}
@@ -111,7 +111,7 @@ private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with L
}
- UIUtils.headerSparkPage("SQL", summary ++ content, parent, Some(5000))
+ UIUtils.headerSparkPage(request, "SQL", summary ++ content, parent, Some(5000))
}
}
@@ -133,7 +133,10 @@ private[ui] abstract class ExecutionTable(
protected def header: Seq[String]
- protected def row(currentTime: Long, executionUIData: SQLExecutionUIData): Seq[Node] = {
+ protected def row(
+ request: HttpServletRequest,
+ currentTime: Long,
+ executionUIData: SQLExecutionUIData): Seq[Node] = {
val submissionTime = executionUIData.submissionTime
val duration = executionUIData.completionTime.map(_.getTime()).getOrElse(currentTime) -
submissionTime
@@ -141,7 +144,7 @@ private[ui] abstract class ExecutionTable(
def jobLinks(status: JobExecutionStatus): Seq[Node] = {
executionUIData.jobs.flatMap { case (jobId, jobStatus) =>
if (jobStatus == status) {
-          <a href={jobURL(jobId)}>[{jobId.toString}]</a>
+          <a href={jobURL(request, jobId)}>[{jobId.toString}]</a>
} else {
None
}
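The pattern throughout this file is threading the incoming HttpServletRequest into every rendering helper, so generated links are built against the base path of the request that is actually being served (for example behind a reverse proxy) rather than a value fixed when the table was constructed. Below is a self-contained sketch of that shape using Scala XML literals; the RequestContext case class and the hand-rolled link building are stand-ins for the servlet request and Spark's UIUtils, not their real APIs.

```scala
import scala.xml.Node

// Hypothetical stand-in for the servlet request: only the piece of state the
// sketch needs (a base path derived from the incoming request).
case class RequestContext(basePath: String)

// Rendering helpers take the request context explicitly instead of relying on
// a value captured at construction time, so links stay correct behind a proxy.
class QueryTable(rows: Seq[(Long, String)]) {
  private def jobLink(ctx: RequestContext, jobId: Long): Node =
    <a href={s"${ctx.basePath}/jobs/job/?id=$jobId"}>[{jobId.toString}]</a>

  def toNodeSeq(ctx: RequestContext): Seq[Node] =
    rows.map { case (jobId, desc) =>
      <tr><td>{desc}</td><td>{jobLink(ctx, jobId)}</td></tr>
    }
}

object QueryTableDemo extends App {
  val table = new QueryTable(Seq((1L, "query 1"), (2L, "query 2")))
  println(table.toNodeSeq(RequestContext("/proxy/app-1234")).mkString("\n"))
}
```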
@@ -153,7 +156,7 @@ private[ui] abstract class ExecutionTable(
        {executionUIData.executionId.toString}
      </td>
      <td>
-        {descriptionCell(executionUIData)}
+        {descriptionCell(request, executionUIData)}
      </td>
      <td sorttable_customkey={submissionTime.toString}>
        {UIUtils.formatDate(submissionTime)}
@@ -179,7 +182,9 @@ private[ui] abstract class ExecutionTable(
    </tr>
  }

- private def descriptionCell(execution: SQLExecutionUIData): Seq[Node] = {
+ private def descriptionCell(
+ request: HttpServletRequest,
+ execution: SQLExecutionUIData): Seq[Node] = {
val details = if (execution.details != null && execution.details.nonEmpty) {