diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 855eb5bf77f16..f52d785e05cdd 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -13,6 +13,7 @@ Authors@R: c(person("Shivaram", "Venkataraman", role = c("aut", "cre"), License: Apache License (== 2.0) URL: http://www.apache.org/ http://spark.apache.org/ BugReports: http://spark.apache.org/contributing.html +SystemRequirements: Java (== 8) Depends: R (>= 3.0), methods diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 5f8209689a559..c575fe255f57a 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -352,6 +352,7 @@ exportMethods("%<=>%", "sinh", "size", "skewness", + "slice", "sort_array", "soundex", "spark_partition_id", diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index 7244cc9f9e38e..4c87f64e7f0e1 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -60,6 +60,41 @@ generateSparkSubmitArgs <- function(args, sparkHome, jars, sparkSubmitOpts, pack combinedArgs } +checkJavaVersion <- function() { + javaBin <- "java" + javaHome <- Sys.getenv("JAVA_HOME") + javaReqs <- utils::packageDescription(utils::packageName(), fields = c("SystemRequirements")) + sparkJavaVersion <- as.numeric(tail(strsplit(javaReqs, "[(=)]")[[1]], n = 1L)) + if (javaHome != "") { + javaBin <- file.path(javaHome, "bin", javaBin) + } + + # If java is missing from PATH, we get an error in Unix and a warning in Windows + javaVersionOut <- tryCatch( + launchScript(javaBin, "-version", wait = TRUE, stdout = TRUE, stderr = TRUE), + error = function(e) { + stop("Java version check failed. Please make sure Java is installed", + " and set JAVA_HOME to point to the installation directory.", e) + }, + warning = function(w) { + stop("Java version check failed. Please make sure Java is installed", + " and set JAVA_HOME to point to the installation directory.", w) + }) + javaVersionFilter <- Filter( + function(x) { + grepl(" version", x) + }, javaVersionOut) + + javaVersionStr <- strsplit(javaVersionFilter[[1]], "[\"]")[[1L]][2] + # javaVersionStr is of the form 1.8.0_92. + # Extract 8 from it to compare to sparkJavaVersion + javaVersionNum <- as.integer(strsplit(javaVersionStr, "[.]")[[1L]][2]) + if (javaVersionNum != sparkJavaVersion) { + stop(paste("Java version", sparkJavaVersion, "is required for this package; found version:", + javaVersionStr)) + } +} + launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts, packages) { sparkSubmitBinName <- determineSparkSubmitBin() if (sparkHome != "") { @@ -67,6 +102,7 @@ launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts, packages) { } else { sparkSubmitBin <- sparkSubmitBinName } + combinedArgs <- generateSparkSubmitArgs(args, sparkHome, jars, sparkSubmitOpts, packages) cat("Launching java with spark-submit command", sparkSubmitBin, combinedArgs, "\n") invisible(launchScript(sparkSubmitBin, combinedArgs)) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 1f97054443e1b..fcb3521f901ea 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -208,16 +208,20 @@ NULL #' head(select(tmp, array_contains(tmp$v1, 21), size(tmp$v1))) #' head(select(tmp, array_max(tmp$v1), array_min(tmp$v1))) #' head(select(tmp, array_position(tmp$v1, 21), array_sort(tmp$v1))) -#' head(select(tmp, flatten(tmp$v1))) +#' head(select(tmp, flatten(tmp$v1), reverse(tmp$v1))) #' tmp2 <- mutate(tmp, v2 = explode(tmp$v1)) #' head(tmp2) #' head(select(tmp, posexplode(tmp$v1))) +#' head(select(tmp, slice(tmp$v1, 2L, 2L))) #' head(select(tmp, sort_array(tmp$v1))) #' head(select(tmp, sort_array(tmp$v1, asc = FALSE))) #' tmp3 <- mutate(df, v3 = create_map(df$model, df$cyl)) #' head(select(tmp3, map_keys(tmp3$v3))) #' head(select(tmp3, map_values(tmp3$v3))) -#' head(select(tmp3, element_at(tmp3$v3, "Valiant")))} +#' head(select(tmp3, element_at(tmp3$v3, "Valiant"))) +#' tmp4 <- mutate(df, v4 = create_array(df$mpg, df$cyl), v5 = create_array(df$hp)) +#' head(select(tmp4, concat(tmp4$v4, tmp4$v5))) +#' head(select(tmp, concat(df$mpg, df$cyl, df$hp)))} NULL #' Window functions for Column operations @@ -1259,9 +1263,9 @@ setMethod("quarter", }) #' @details -#' \code{reverse}: Reverses the string column and returns it as a new string column. +#' \code{reverse}: Returns a reversed string or an array with reverse order of elements. #' -#' @rdname column_string_functions +#' @rdname column_collection_functions #' @aliases reverse reverse,Column-method #' @note reverse since 1.5.0 setMethod("reverse", @@ -1912,6 +1916,7 @@ setMethod("atan2", signature(y = "Column"), #' @details #' \code{datediff}: Returns the number of days from \code{y} to \code{x}. +#' If \code{y} is later than \code{x} then the result is positive. #' #' @rdname column_datetime_diff_functions #' @aliases datediff datediff,Column-method @@ -1971,7 +1976,10 @@ setMethod("levenshtein", signature(y = "Column"), }) #' @details -#' \code{months_between}: Returns number of months between dates \code{y} and \code{x}. +#' \code{months_between}: Returns number of months between dates \code{y} and \code{x}. +#' If \code{y} is later than \code{x}, then the result is positive. If \code{y} and \code{x} +#' are on the same day of month, or both are the last day of month, time of day will be ignored. +#' Otherwise, the difference is calculated based on 31 days per month, and rounded to 8 digits. #' #' @rdname column_datetime_diff_functions #' @aliases months_between months_between,Column-method @@ -2050,20 +2058,10 @@ setMethod("countDistinct", #' @details #' \code{concat}: Concatenates multiple input columns together into a single column. -#' If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string. +#' The function works with strings, binary and compatible array columns. #' -#' @rdname column_string_functions +#' @rdname column_collection_functions #' @aliases concat concat,Column-method -#' @examples -#' -#' \dontrun{ -#' # concatenate strings -#' tmp <- mutate(df, s1 = concat(df$Class, df$Sex), -#' s2 = concat(df$Class, df$Sex, df$Age), -#' s3 = concat(df$Class, df$Sex, df$Age, df$Class), -#' s4 = concat_ws("_", df$Class, df$Sex), -#' s5 = concat_ws("+", df$Class, df$Sex, df$Age, df$Survived)) -#' head(tmp)} #' @note concat since 1.5.0 setMethod("concat", signature(x = "Column"), @@ -2404,6 +2402,13 @@ setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), #' @param sep separator to use. #' @rdname column_string_functions #' @aliases concat_ws concat_ws,character,Column-method +#' @examples +#' +#' \dontrun{ +#' # concatenate strings +#' tmp <- mutate(df, s1 = concat_ws("_", df$Class, df$Sex), +#' s2 = concat_ws("+", df$Class, df$Sex, df$Age, df$Survived)) +#' head(tmp)} #' @note concat_ws since 1.5.0 setMethod("concat_ws", signature(sep = "character", x = "Column"), function(sep, x, ...) { @@ -3058,7 +3063,8 @@ setMethod("array_sort", }) #' @details -#' \code{flatten}: Transforms an array of arrays into a single array. +#' \code{flatten}: Creates a single array from an array of arrays. +#' If a structure of nested arrays is deeper than two levels, only one level of nesting is removed. #' #' @rdname column_collection_functions #' @aliases flatten flatten,Column-method @@ -3138,6 +3144,22 @@ setMethod("size", column(jc) }) +#' @details +#' \code{slice}: Returns an array containing all the elements in x from the index start +#' (or starting from the end if start is negative) with the specified length. +#' +#' @rdname column_collection_functions +#' @param start an index indicating the first element occuring in the result. +#' @param length a number of consecutive elements choosen to the result. +#' @aliases slice slice,Column-method +#' @note slice since 2.4.0 +setMethod("slice", + signature(x = "Column"), + function(x, start, length) { + jc <- callJStatic("org.apache.spark.sql.functions", "slice", x@jc, start, length) + column(jc) + }) + #' @details #' \code{sort_array}: Sorts the input array in ascending or descending order according to #' the natural ordering of the array elements. NA elements will be placed at the beginning of diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 5faa51eef3abd..3ea181157b644 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -624,7 +624,7 @@ setGeneric("summarize", function(x, ...) { standardGeneric("summarize") }) #' @rdname summary setGeneric("summary", function(object, ...) { standardGeneric("summary") }) -setGeneric("toJSON", function(x) { standardGeneric("toJSON") }) +setGeneric("toJSON", function(x, ...) { standardGeneric("toJSON") }) setGeneric("toRDD", function(x) { standardGeneric("toRDD") }) @@ -817,7 +817,7 @@ setGeneric("collect_set", function(x) { standardGeneric("collect_set") }) #' @rdname column setGeneric("column", function(x) { standardGeneric("column") }) -#' @rdname column_string_functions +#' @rdname column_collection_functions #' @name NULL setGeneric("concat", function(x, ...) { standardGeneric("concat") }) @@ -1134,7 +1134,7 @@ setGeneric("regexp_replace", #' @name NULL setGeneric("repeat_string", function(x, n) { standardGeneric("repeat_string") }) -#' @rdname column_string_functions +#' @rdname column_collection_functions #' @name NULL setGeneric("reverse", function(x) { standardGeneric("reverse") }) @@ -1194,6 +1194,10 @@ setGeneric("size", function(x) { standardGeneric("size") }) #' @name NULL setGeneric("skewness", function(x) { standardGeneric("skewness") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("slice", function(x, start, length) { standardGeneric("slice") }) + #' @rdname column_collection_functions #' @name NULL setGeneric("sort_array", function(x, asc = TRUE) { standardGeneric("sort_array") }) diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 38ee79477996f..f7c1663d32c96 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -167,6 +167,7 @@ sparkR.sparkContext <- function( submitOps <- getClientModeSparkSubmitOpts( Sys.getenv("SPARKR_SUBMIT_ARGS", "sparkr-shell"), sparkEnvirMap) + checkJavaVersion() launchBackend( args = path, sparkHome = sparkHome, @@ -193,7 +194,7 @@ sparkR.sparkContext <- function( # Don't use readString() so that we can provide a useful # error message if the R and Java versions are mismatched. - authSecretLen = readInt(f) + authSecretLen <- readInt(f) if (length(authSecretLen) == 0 || authSecretLen == 0) { stop("Unexpected EOF in JVM connection data. Mismatched versions?") } diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index f1b5ecaa017df..c3501977e64bc 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -746,7 +746,7 @@ varargsToJProperties <- function(...) { props } -launchScript <- function(script, combinedArgs, wait = FALSE) { +launchScript <- function(script, combinedArgs, wait = FALSE, stdout = "", stderr = "") { if (.Platform$OS.type == "windows") { scriptWithArgs <- paste(script, combinedArgs, sep = " ") # on Windows, intern = F seems to mean output to the console. (documentation on this is missing) @@ -756,7 +756,7 @@ launchScript <- function(script, combinedArgs, wait = FALSE) { # stdout = F means discard output # stdout = "" means to its console (default) # Note that the console of this child process might not be the same as the running R process. - system2(script, combinedArgs, stdout = "", wait = wait) + system2(script, combinedArgs, stdout = stdout, wait = wait, stderr = stderr) } } diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index b8bfded0ebf2d..13b55ac6e6e3c 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1479,7 +1479,7 @@ test_that("column functions", { df5 <- createDataFrame(list(list(a = "010101"))) expect_equal(collect(select(df5, conv(df5$a, 2, 16)))[1, 1], "15") - # Test array_contains(), array_max(), array_min(), array_position() and element_at() + # Test array_contains(), array_max(), array_min(), array_position(), element_at() and reverse() df <- createDataFrame(list(list(list(1L, 2L, 3L)), list(list(6L, 5L, 4L)))) result <- collect(select(df, array_contains(df[[1]], 1L)))[[1]] expect_equal(result, c(TRUE, FALSE)) @@ -1496,6 +1496,13 @@ test_that("column functions", { result <- collect(select(df, element_at(df[[1]], 1L)))[[1]] expect_equal(result, c(1, 6)) + result <- collect(select(df, reverse(df[[1]])))[[1]] + expect_equal(result, list(list(3L, 2L, 1L), list(4L, 5L, 6L))) + + df2 <- createDataFrame(list(list("abc"))) + result <- collect(select(df2, reverse(df2[[1]])))[[1]] + expect_equal(result, "cba") + # Test array_sort() and sort_array() df <- createDataFrame(list(list(list(2L, 1L, 3L, NA)), list(list(NA, 6L, 5L, NA, 4L)))) @@ -1507,7 +1514,18 @@ test_that("column functions", { result <- collect(select(df, sort_array(df[[1]])))[[1]] expect_equal(result, list(list(NA, 1L, 2L, 3L), list(NA, NA, 4L, 5L, 6L))) - # Test flattern + # Test slice() + df <- createDataFrame(list(list(list(1L, 2L, 3L)), list(list(4L, 5L)))) + result <- collect(select(df, slice(df[[1]], 2L, 2L)))[[1]] + expect_equal(result, list(list(2L, 3L), list(5L))) + + # Test concat() + df <- createDataFrame(list(list(list(1L, 2L, 3L), list(4L, 5L, 6L)), + list(list(7L, 8L, 9L), list(10L, 11L, 12L)))) + result <- collect(select(df, concat(df[[1]], df[[2]])))[[1]] + expect_equal(result, list(list(1L, 2L, 3L, 4L, 5L, 6L), list(7L, 8L, 9L, 10L, 11L, 12L))) + + # Test flatten() df <- createDataFrame(list(list(list(list(1L, 2L), list(3L, 4L))), list(list(list(5L, 6L), list(7L, 8L))))) result <- collect(select(df, flatten(df[[1]])))[[1]] diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index cccd3ea457ba4..0791fe856ef15 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -178,4 +178,6 @@ private[spark] class TaskContextImpl( private[spark] def fetchFailed: Option[FetchFailedException] = _fetchFailedException + // TODO: shall we publish it and define it in `TaskContext`? + private[spark] def getLocalProperties(): Properties = localProperties } diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index a76283e33fa65..33901bc8380e9 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -212,9 +212,15 @@ case object TaskResultLost extends TaskFailedReason { * Task was killed intentionally and needs to be rescheduled. */ @DeveloperApi -case class TaskKilled(reason: String) extends TaskFailedReason { +case class TaskKilled( + reason: String, + accumUpdates: Seq[AccumulableInfo] = Seq.empty, + private[spark] val accums: Seq[AccumulatorV2[_, _]] = Nil) + extends TaskFailedReason { + override def toErrorString: String = s"TaskKilled ($reason)" override def countTowardsTaskFailures: Boolean = false + } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 087e9c31a9c9a..4baf032f0e9c6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -310,6 +310,7 @@ private[spark] class SparkSubmit extends Logging { val isMesosCluster = clusterManager == MESOS && deployMode == CLUSTER val isStandAloneCluster = clusterManager == STANDALONE && deployMode == CLUSTER val isKubernetesCluster = clusterManager == KUBERNETES && deployMode == CLUSTER + val isMesosClient = clusterManager == MESOS && deployMode == CLIENT if (!isMesosCluster && !isStandAloneCluster) { // Resolve maven dependencies if there are any and add classpath to jars. Add them to py-files @@ -337,7 +338,7 @@ private[spark] class SparkSubmit extends Logging { val targetDir = Utils.createTempDir() // assure a keytab is available from any place in a JVM - if (clusterManager == YARN || clusterManager == LOCAL || clusterManager == MESOS) { + if (clusterManager == YARN || clusterManager == LOCAL || isMesosClient) { if (args.principal != null) { if (args.keytab != null) { require(new File(args.keytab).exists(), s"Keytab file: ${args.keytab} does not exist") diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 0733fdb72cafb..fed4e0a5069c3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -36,7 +36,6 @@ import org.apache.spark.launcher.SparkSubmitArgumentsParser import org.apache.spark.network.util.JavaUtils import org.apache.spark.util.Utils - /** * Parses and encapsulates arguments from the spark-submit script. * The env argument is used for testing. @@ -76,6 +75,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S var proxyUser: String = null var principal: String = null var keytab: String = null + private var dynamicAllocationEnabled: Boolean = false // Standalone cluster mode only var supervise: Boolean = false @@ -198,6 +198,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S queue = Option(queue).orElse(sparkProperties.get("spark.yarn.queue")).orNull keytab = Option(keytab).orElse(sparkProperties.get("spark.yarn.keytab")).orNull principal = Option(principal).orElse(sparkProperties.get("spark.yarn.principal")).orNull + dynamicAllocationEnabled = + sparkProperties.get("spark.dynamicAllocation.enabled").exists("true".equalsIgnoreCase) // Try to set main class from JAR if no --class argument is given if (mainClass == null && !isPython && !isR && primaryResource != null) { @@ -274,7 +276,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S if (totalExecutorCores != null && Try(totalExecutorCores.toInt).getOrElse(-1) <= 0) { error("Total executor cores must be a positive number") } - if (numExecutors != null && Try(numExecutors.toInt).getOrElse(-1) <= 0) { + if (!dynamicAllocationEnabled && + numExecutors != null && Try(numExecutors.toInt).getOrElse(-1) <= 0) { error("Number of executors must be a positive number") } if (pyFiles != null && !isPython) { diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 6fc12d721e6f1..32667ddf5c7ea 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -37,8 +37,8 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") val lastUpdatedTime = parent.getLastUpdatedTime() val providerConfig = parent.getProviderConfig() val content = - ++ - + ++ +
- UIUtils.basicSparkPage(content, "History Server", true) + UIUtils.basicSparkPage(request, content, "History Server", true) } - private def makePageLink(showIncomplete: Boolean): String = { - UIUtils.prependBaseUri("/?" + "showIncomplete=" + showIncomplete) + private def makePageLink(request: HttpServletRequest, showIncomplete: Boolean): String = { + UIUtils.prependBaseUri(request, "/?" + "showIncomplete=" + showIncomplete) } private def isApplicationCompleted(appInfo: ApplicationInfo): Boolean = { diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 611fa563a7cd9..a9a4d5a4ec6a2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -87,7 +87,7 @@ class HistoryServer( if (!loadAppUi(appId, None) && (!attemptId.isDefined || !loadAppUi(appId, attemptId))) { val msg =
Application {appId} not found.
res.setStatus(HttpServletResponse.SC_NOT_FOUND) - UIUtils.basicSparkPage(msg, "Not Found").foreach { n => + UIUtils.basicSparkPage(req, msg, "Not Found").foreach { n => res.getWriter().write(n.toString) } return diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index f699c75085fe1..fad4e46dc035d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -40,7 +40,7 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") .getOrElse(state.completedApps.find(_.id == appId).orNull) if (app == null) { val msg =
No running application with ID {appId}
- return UIUtils.basicSparkPage(msg, "Not Found") + return UIUtils.basicSparkPage(request, msg, "Not Found") } val executorHeaders = Seq("ExecutorID", "Worker", "Cores", "Memory", "State", "Logs") @@ -127,7 +127,7 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") } ; - UIUtils.basicSparkPage(content, "Application: " + app.desc.name) + UIUtils.basicSparkPage(request, content, "Application: " + app.desc.name) } private def executorRow(executor: ExecutorDesc): Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index c629937606b51..b8afe203fbfa2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -215,7 +215,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { } ; - UIUtils.basicSparkPage(content, "Spark Master at " + state.uri) + UIUtils.basicSparkPage(request, content, "Spark Master at " + state.uri) } private def workerRow(worker: WorkerInfo): Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala index 2f5a5642d3cab..4fca9342c0378 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala @@ -118,7 +118,7 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with - UIUtils.basicSparkPage(content, logType + " log page for " + pageName) + UIUtils.basicSparkPage(request, content, logType + " log page for " + pageName) } /** Get the part of the log files given the offset and desired length of bytes */ diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala index 8b98ae56fc108..aa4e28d213e2b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala @@ -135,7 +135,7 @@ private[ui] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") { } ; - UIUtils.basicSparkPage(content, "Spark Worker at %s:%s".format( + UIUtils.basicSparkPage(request, content, "Spark Worker at %s:%s".format( workerState.host, workerState.port)) } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index c325222b764b8..b1856ff0f3247 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -287,6 +287,28 @@ private[spark] class Executor( notifyAll() } + /** + * Utility function to: + * 1. Report executor runtime and JVM gc time if possible + * 2. Collect accumulator updates + * 3. Set the finished flag to true and clear current thread's interrupt status + */ + private def collectAccumulatorsAndResetStatusOnFailure(taskStartTime: Long) = { + // Report executor runtime and JVM gc time + Option(task).foreach(t => { + t.metrics.setExecutorRunTime(System.currentTimeMillis() - taskStartTime) + t.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime) + }) + + // Collect latest accumulator values to report back to the driver + val accums: Seq[AccumulatorV2[_, _]] = + Option(task).map(_.collectAccumulatorUpdates(taskFailed = true)).getOrElse(Seq.empty) + val accUpdates = accums.map(acc => acc.toInfo(Some(acc.value), None)) + + setTaskFinishedAndClearInterruptStatus() + (accums, accUpdates) + } + override def run(): Unit = { threadId = Thread.currentThread.getId Thread.currentThread.setName(threadName) @@ -300,7 +322,7 @@ private[spark] class Executor( val ser = env.closureSerializer.newInstance() logInfo(s"Running $taskName (TID $taskId)") execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER) - var taskStart: Long = 0 + var taskStartTime: Long = 0 var taskStartCpu: Long = 0 startGCTime = computeTotalGcTime() @@ -336,7 +358,7 @@ private[spark] class Executor( } // Run the actual task and measure its runtime. - taskStart = System.currentTimeMillis() + taskStartTime = System.currentTimeMillis() taskStartCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) { threadMXBean.getCurrentThreadCpuTime } else 0L @@ -396,11 +418,11 @@ private[spark] class Executor( // Deserialization happens in two parts: first, we deserialize a Task object, which // includes the Partition. Second, Task.run() deserializes the RDD and function to be run. task.metrics.setExecutorDeserializeTime( - (taskStart - deserializeStartTime) + task.executorDeserializeTime) + (taskStartTime - deserializeStartTime) + task.executorDeserializeTime) task.metrics.setExecutorDeserializeCpuTime( (taskStartCpu - deserializeStartCpuTime) + task.executorDeserializeCpuTime) // We need to subtract Task.run()'s deserialization time to avoid double-counting - task.metrics.setExecutorRunTime((taskFinish - taskStart) - task.executorDeserializeTime) + task.metrics.setExecutorRunTime((taskFinish - taskStartTime) - task.executorDeserializeTime) task.metrics.setExecutorCpuTime( (taskFinishCpu - taskStartCpu) - task.executorDeserializeCpuTime) task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime) @@ -482,16 +504,19 @@ private[spark] class Executor( } catch { case t: TaskKilledException => logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}") - setTaskFinishedAndClearInterruptStatus() - execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason))) + + val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTime) + val serializedTK = ser.serialize(TaskKilled(t.reason, accUpdates, accums)) + execBackend.statusUpdate(taskId, TaskState.KILLED, serializedTK) case _: InterruptedException | NonFatal(_) if task != null && task.reasonIfKilled.isDefined => val killReason = task.reasonIfKilled.getOrElse("unknown reason") logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason") - setTaskFinishedAndClearInterruptStatus() - execBackend.statusUpdate( - taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason))) + + val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTime) + val serializedTK = ser.serialize(TaskKilled(killReason, accUpdates, accums)) + execBackend.statusUpdate(taskId, TaskState.KILLED, serializedTK) case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) => val reason = task.context.fetchFailed.get.toTaskFailedReason @@ -524,17 +549,7 @@ private[spark] class Executor( // the task failure would not be ignored if the shutdown happened because of premption, // instead of an app issue). if (!ShutdownHookManager.inShutdown()) { - // Collect latest accumulator values to report back to the driver - val accums: Seq[AccumulatorV2[_, _]] = - if (task != null) { - task.metrics.setExecutorRunTime(System.currentTimeMillis() - taskStart) - task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime) - task.collectAccumulatorUpdates(taskFailed = true) - } else { - Seq.empty - } - - val accUpdates = accums.map(acc => acc.toInfo(Some(acc.value), None)) + val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTime) val serializedTaskEndReason = { try { diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 82f0a04e94b1c..a54b091a64d50 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -342,7 +342,7 @@ package object config { "a property key or value, the value is redacted from the environment UI and various logs " + "like YARN and event logs.") .regexConf - .createWithDefault("(?i)secret|password|url|user|username".r) + .createWithDefault("(?i)secret|password".r) private[spark] val STRING_REDACTION_PATTERN = ConfigBuilder("spark.redaction.string.regex") diff --git a/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala b/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala index c1fedd63f6a90..e2b6df4600590 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala @@ -34,7 +34,11 @@ import org.apache.spark.util.Utils * Delivery will only begin when the `start()` method is called. The `stop()` method should be * called when no more events need to be delivered. */ -private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveListenerBusMetrics) +private class AsyncEventQueue( + val name: String, + conf: SparkConf, + metrics: LiveListenerBusMetrics, + bus: LiveListenerBus) extends SparkListenerBus with Logging { @@ -81,23 +85,18 @@ private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveLi } private def dispatch(): Unit = LiveListenerBus.withinListenerThread.withValue(true) { - try { - var next: SparkListenerEvent = eventQueue.take() - while (next != POISON_PILL) { - val ctx = processingTime.time() - try { - super.postToAll(next) - } finally { - ctx.stop() - } - eventCount.decrementAndGet() - next = eventQueue.take() + var next: SparkListenerEvent = eventQueue.take() + while (next != POISON_PILL) { + val ctx = processingTime.time() + try { + super.postToAll(next) + } finally { + ctx.stop() } eventCount.decrementAndGet() - } catch { - case ie: InterruptedException => - logInfo(s"Stopping listener queue $name.", ie) + next = eventQueue.take() } + eventCount.decrementAndGet() } override protected def getTimer(listener: SparkListenerInterface): Option[Timer] = { @@ -130,7 +129,11 @@ private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveLi eventCount.incrementAndGet() eventQueue.put(POISON_PILL) } - dispatchThread.join() + // this thread might be trying to stop itself as part of error handling -- we can't join + // in that case. + if (Thread.currentThread() != dispatchThread) { + dispatchThread.join() + } } def post(event: SparkListenerEvent): Unit = { @@ -187,6 +190,12 @@ private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveLi true } + override def removeListenerOnError(listener: SparkListenerInterface): Unit = { + // the listener failed in an unrecoverably way, we want to remove it from the entire + // LiveListenerBus (potentially stopping a queue if it is empty) + bus.removeListener(listener) + } + } private object AsyncEventQueue { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 78b6b34b5d2bb..ea7bfd7d7a68d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -206,7 +206,7 @@ class DAGScheduler( private val messageScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("dag-scheduler-message") - private[scheduler] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this) + private[spark] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this) taskScheduler.setDAGScheduler(this) /** @@ -1210,7 +1210,7 @@ class DAGScheduler( case _ => updateAccumulators(event) } - case _: ExceptionFailure => updateAccumulators(event) + case _: ExceptionFailure | _: TaskKilled => updateAccumulators(event) case _ => } postTaskEnd(event) @@ -1414,13 +1414,13 @@ class DAGScheduler( case commitDenied: TaskCommitDenied => // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits - case exceptionFailure: ExceptionFailure => + case _: ExceptionFailure | _: TaskKilled => // Nothing left to do, already handled above for accumulator updates. case TaskResultLost => // Do nothing here; the TaskScheduler handles these failures and resubmits the task. - case _: ExecutorLostFailure | _: TaskKilled | UnknownReason => + case _: ExecutorLostFailure | UnknownReason => // Unrecognized failure - also do nothing. If the task fails repeatedly, the TaskScheduler // will abort the job. } diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala index ba6387a8f08ad..d135190d1e919 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala @@ -102,7 +102,7 @@ private[spark] class LiveListenerBus(conf: SparkConf) { queue.addListener(listener) case None => - val newQueue = new AsyncEventQueue(queue, conf, metrics) + val newQueue = new AsyncEventQueue(queue, conf, metrics, this) newQueue.addListener(listener) if (started.get()) { newQueue.start(sparkContext) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 195fc8025e4b5..a18c66596852a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -851,13 +851,19 @@ private[spark] class TaskSetManager( } ef.exception + case tk: TaskKilled => + // TaskKilled might have accumulator updates + accumUpdates = tk.accums + logWarning(failureReason) + None + case e: ExecutorLostFailure if !e.exitCausedByApp => logInfo(s"Task $tid failed because while it was being computed, its executor " + "exited for a reason unrelated to the task. Not counting this failure towards the " + "maximum number of failures for the task.") None - case e: TaskFailedReason => // TaskResultLost, TaskKilled, and others + case e: TaskFailedReason => // TaskResultLost and others logWarning(failureReason) None } diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 02cf19e00ecde..5d015b0531ef6 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -20,6 +20,7 @@ package org.apache.spark.ui import java.net.URLDecoder import java.text.SimpleDateFormat import java.util.{Date, Locale, TimeZone} +import javax.servlet.http.HttpServletRequest import scala.util.control.NonFatal import scala.xml._ @@ -148,60 +149,71 @@ private[spark] object UIUtils extends Logging { } // Yarn has to go through a proxy so the base uri is provided and has to be on all links - def uiRoot: String = { + def uiRoot(request: HttpServletRequest): String = { + // Knox uses X-Forwarded-Context to notify the application the base path + val knoxBasePath = Option(request.getHeader("X-Forwarded-Context")) // SPARK-11484 - Use the proxyBase set by the AM, if not found then use env. sys.props.get("spark.ui.proxyBase") .orElse(sys.env.get("APPLICATION_WEB_PROXY_BASE")) + .orElse(knoxBasePath) .getOrElse("") } - def prependBaseUri(basePath: String = "", resource: String = ""): String = { - uiRoot + basePath + resource + def prependBaseUri( + request: HttpServletRequest, + basePath: String = "", + resource: String = ""): String = { + uiRoot(request) + basePath + resource } - def commonHeaderNodes: Seq[Node] = { + def commonHeaderNodes(request: HttpServletRequest): Seq[Node] = { - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + } - def vizHeaderNodes: Seq[Node] = { - - - - - + def vizHeaderNodes(request: HttpServletRequest): Seq[Node] = { + + + + + } - def dataTablesHeaderNodes: Seq[Node] = { + def dataTablesHeaderNodes(request: HttpServletRequest): Seq[Node] = { + + href={prependBaseUri(request, "/static/dataTables.bootstrap.css")} type="text/css"/> - - - - - - - + href={prependBaseUri(request, "/static/jsonFormatter.min.css")} type="text/css"/> + + + + + + } /** Returns a spark page with correctly formatted headers */ def headerSparkPage( + request: HttpServletRequest, title: String, content: => Seq[Node], activeTab: SparkUITab, @@ -214,25 +226,26 @@ private[spark] object UIUtils extends Logging { val shortAppName = if (appName.length < 36) appName else appName.take(32) + "..." val header = activeTab.headerTabs.map { tab =>
  • - {tab.name} + {tab.name}
  • } val helpButton: Seq[Node] = helpText.map(tooltip(_, "bottom")).getOrElse(Seq.empty) - {commonHeaderNodes} - {if (showVisualization) vizHeaderNodes else Seq.empty} - {if (useDataTables) dataTablesHeaderNodes else Seq.empty} - + {commonHeaderNodes(request)} + {if (showVisualization) vizHeaderNodes(request) else Seq.empty} + {if (useDataTables) dataTablesHeaderNodes(request) else Seq.empty} + {appName} - {title} } - UIUtils.headerSparkPage(s"Details for Job $jobId", content, parent, showVisualization = true) + UIUtils.headerSparkPage( + request, s"Details for Job $jobId", content, parent, showVisualization = true) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala index a3e1f13782e30..22a40101e33df 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala @@ -49,7 +49,7 @@ private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") { "stages/pool", parent.isFairScheduler, parent.killEnabled, false) val poolTable = new PoolTable(Map(pool -> uiPool), parent) - var content =

    Summary

    ++ poolTable.toNodeSeq + var content =

    Summary

    ++ poolTable.toNodeSeq(request) if (activeStages.nonEmpty) { content ++= } - UIUtils.headerSparkPage("Fair Scheduler Pool: " + poolName, content, parent) + UIUtils.headerSparkPage(request, "Fair Scheduler Pool: " + poolName, content, parent) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala index 5dfce858dec07..96b5f72393070 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala @@ -18,6 +18,7 @@ package org.apache.spark.ui.jobs import java.net.URLEncoder +import javax.servlet.http.HttpServletRequest import scala.xml.Node @@ -28,7 +29,7 @@ import org.apache.spark.ui.UIUtils /** Table showing list of pools */ private[ui] class PoolTable(pools: Map[Schedulable, PoolData], parent: StagesTab) { - def toNodeSeq: Seq[Node] = { + def toNodeSeq(request: HttpServletRequest): Seq[Node] = { @@ -39,15 +40,15 @@ private[ui] class PoolTable(pools: Map[Schedulable, PoolData], parent: StagesTab - {pools.map { case (s, p) => poolRow(s, p) }} + {pools.map { case (s, p) => poolRow(request, s, p) }}
    Pool NameSchedulingMode
    } - private def poolRow(s: Schedulable, p: PoolData): Seq[Node] = { + private def poolRow(request: HttpServletRequest, s: Schedulable, p: PoolData): Seq[Node] = { val activeStages = p.stageIds.size val href = "%s/stages/pool?poolname=%s" - .format(UIUtils.prependBaseUri(parent.basePath), URLEncoder.encode(p.name, "UTF-8")) + .format(UIUtils.prependBaseUri(request, parent.basePath), URLEncoder.encode(p.name, "UTF-8")) {p.name} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index ac83de10f9237..2575914121c39 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -112,7 +112,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We

    No information to display for Stage {stageId} (Attempt {stageAttemptId})

    - return UIUtils.headerSparkPage(stageHeader, content, parent) + return UIUtils.headerSparkPage(request, stageHeader, content, parent) } val localitySummary = store.localitySummary(stageData.stageId, stageData.attemptId) @@ -125,7 +125,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We

    Summary Metrics

    No tasks have started yet

    Tasks

    No tasks have started yet - return UIUtils.headerSparkPage(stageHeader, content, parent) + return UIUtils.headerSparkPage(request, stageHeader, content, parent) } val storedTasks = store.taskCount(stageData.stageId, stageData.attemptId) @@ -282,7 +282,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We val (taskTable, taskTableHTML) = try { val _taskTable = new TaskPagedTable( stageData, - UIUtils.prependBaseUri(parent.basePath) + + UIUtils.prependBaseUri(request, parent.basePath) + s"/stages/stage?id=${stageId}&attempt=${stageAttemptId}", currentTime, pageSize = taskPageSize, @@ -498,7 +498,7 @@ private[ui] class StagePage(parent: StagesTab, store: AppStatusStore) extends We
    {taskTableHTML ++ jsForScrollingDownToTaskTable}
    - UIUtils.headerSparkPage(stageHeader, content, parent, showVisualization = true) + UIUtils.headerSparkPage(request, stageHeader, content, parent, showVisualization = true) } def makeTimeline(tasks: Seq[TaskData], currentTime: Long): Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 18a4926f2f6c0..b8b20db1fa407 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -92,7 +92,8 @@ private[ui] class StageTableBase( stageSortColumn, stageSortDesc, isFailedStage, - parameterOtherTable + parameterOtherTable, + request ).table(page) } catch { case e @ (_ : IllegalArgumentException | _ : IndexOutOfBoundsException) => @@ -147,7 +148,8 @@ private[ui] class StagePagedTable( sortColumn: String, desc: Boolean, isFailedStage: Boolean, - parameterOtherTable: Iterable[String]) extends PagedTable[StageTableRowData] { + parameterOtherTable: Iterable[String], + request: HttpServletRequest) extends PagedTable[StageTableRowData] { override def tableId: String = stageTag + "-table" @@ -161,7 +163,7 @@ private[ui] class StagePagedTable( override def pageNumberFormField: String = stageTag + ".page" - val parameterPath = UIUtils.prependBaseUri(basePath) + s"/$subPath/?" + + val parameterPath = UIUtils.prependBaseUri(request, basePath) + s"/$subPath/?" + parameterOtherTable.mkString("&") override val dataSource = new StageDataSource( @@ -288,7 +290,7 @@ private[ui] class StagePagedTable( {if (isFairScheduler) { + .format(UIUtils.prependBaseUri(request, basePath), data.schedulingPool)}> {data.schedulingPool} @@ -346,7 +348,7 @@ private[ui] class StagePagedTable( } private def makeDescription(s: v1.StageData, descriptionOption: Option[String]): Seq[Node] = { - val basePathUri = UIUtils.prependBaseUri(basePath) + val basePathUri = UIUtils.prependBaseUri(request, basePath) val killLink = if (killEnabled) { val confirm = diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index 2674b9291203a..238cd31433660 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -53,7 +53,7 @@ private[ui] class RDDPage(parent: SparkUITab, store: AppStatusStore) extends Web } catch { case _: NoSuchElementException => // Rather than crashing, render an "RDD Not Found" page - return UIUtils.headerSparkPage("RDD Not Found", Seq.empty[Node], parent) + return UIUtils.headerSparkPage(request, "RDD Not Found", Seq.empty[Node], parent) } // Worker table @@ -72,7 +72,7 @@ private[ui] class RDDPage(parent: SparkUITab, store: AppStatusStore) extends Web } val blockTableHTML = try { val _blockTable = new BlockPagedTable( - UIUtils.prependBaseUri(parent.basePath) + s"/storage/rdd/?id=${rddId}", + UIUtils.prependBaseUri(request, parent.basePath) + s"/storage/rdd/?id=${rddId}", rddStorageInfo.partitions.get, blockPageSize, blockSortColumn, @@ -145,7 +145,8 @@ private[ui] class RDDPage(parent: SparkUITab, store: AppStatusStore) extends Web {blockTableHTML ++ jsForScrollingDownToBlockTable} ; - UIUtils.headerSparkPage("RDD Storage Info for " + rddStorageInfo.name, content, parent) + UIUtils.headerSparkPage( + request, "RDD Storage Info for " + rddStorageInfo.name, content, parent) } /** Header fields for the worker table */ diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala index 68d946574a37b..3eb546e336e99 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala @@ -31,11 +31,14 @@ import org.apache.spark.util.Utils private[ui] class StoragePage(parent: SparkUITab, store: AppStatusStore) extends WebUIPage("") { def render(request: HttpServletRequest): Seq[Node] = { - val content = rddTable(store.rddList()) ++ receiverBlockTables(store.streamBlocksList()) - UIUtils.headerSparkPage("Storage", content, parent) + val content = rddTable(request, store.rddList()) ++ + receiverBlockTables(store.streamBlocksList()) + UIUtils.headerSparkPage(request, "Storage", content, parent) } - private[storage] def rddTable(rdds: Seq[v1.RDDStorageInfo]): Seq[Node] = { + private[storage] def rddTable( + request: HttpServletRequest, + rdds: Seq[v1.RDDStorageInfo]): Seq[Node] = { if (rdds.isEmpty) { // Don't show the rdd table if there is no RDD persisted. Nil @@ -49,7 +52,11 @@ private[ui] class StoragePage(parent: SparkUITab, store: AppStatusStore) extends
    - {UIUtils.listingTable(rddHeader, rddRow, rdds, id = Some("storage-by-rdd-table"))} + {UIUtils.listingTable( + rddHeader, + rddRow(request, _: v1.RDDStorageInfo), + rdds, + id = Some("storage-by-rdd-table"))}
    } @@ -66,12 +73,13 @@ private[ui] class StoragePage(parent: SparkUITab, store: AppStatusStore) extends "Size on Disk") /** Render an HTML row representing an RDD */ - private def rddRow(rdd: v1.RDDStorageInfo): Seq[Node] = { + private def rddRow(request: HttpServletRequest, rdd: v1.RDDStorageInfo): Seq[Node] = { // scalastyle:off {rdd.id} - + {rdd.name} diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index 2bc84953a56eb..3b469a69437b9 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -24,6 +24,7 @@ import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicLong import org.apache.spark.{InternalAccumulator, SparkContext, TaskContext} +import org.apache.spark.internal.Logging import org.apache.spark.scheduler.AccumulableInfo private[spark] case class AccumulatorMetadata( @@ -211,7 +212,7 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable { /** * An internal class used to track accumulators by Spark itself. */ -private[spark] object AccumulatorContext { +private[spark] object AccumulatorContext extends Logging { /** * This global map holds the original accumulator objects that are created on the driver. @@ -258,13 +259,16 @@ private[spark] object AccumulatorContext { * Returns the [[AccumulatorV2]] registered with the given ID, if any. */ def get(id: Long): Option[AccumulatorV2[_, _]] = { - Option(originals.get(id)).map { ref => - // Since we are storing weak references, we must check whether the underlying data is valid. + val ref = originals.get(id) + if (ref eq null) { + None + } else { + // Since we are storing weak references, warn when the underlying data is not valid. val acc = ref.get if (acc eq null) { - throw new IllegalStateException(s"Attempted to access garbage collected accumulator $id") + logWarning(s"Attempted to access garbage collected accumulator $id") } - acc + Option(acc) } } diff --git a/core/src/main/scala/org/apache/spark/util/EventLoop.scala b/core/src/main/scala/org/apache/spark/util/EventLoop.scala index 3ea9139e11027..651ea4996f6cb 100644 --- a/core/src/main/scala/org/apache/spark/util/EventLoop.scala +++ b/core/src/main/scala/org/apache/spark/util/EventLoop.scala @@ -37,7 +37,8 @@ private[spark] abstract class EventLoop[E](name: String) extends Logging { private val stopped = new AtomicBoolean(false) - private val eventThread = new Thread(name) { + // Exposed for testing. + private[spark] val eventThread = new Thread(name) { setDaemon(true) override def run(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 40383fe05026b..50c6461373dee 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -407,7 +407,9 @@ private[spark] object JsonProtocol { ("Exit Caused By App" -> exitCausedByApp) ~ ("Loss Reason" -> reason.map(_.toString)) case taskKilled: TaskKilled => - ("Kill Reason" -> taskKilled.reason) + val accumUpdates = JArray(taskKilled.accumUpdates.map(accumulableInfoToJson).toList) + ("Kill Reason" -> taskKilled.reason) ~ + ("Accumulator Updates" -> accumUpdates) case _ => emptyJson } ("Reason" -> reason) ~ json @@ -917,7 +919,10 @@ private[spark] object JsonProtocol { case `taskKilled` => val killReason = jsonOption(json \ "Kill Reason") .map(_.extract[String]).getOrElse("unknown reason") - TaskKilled(killReason) + val accumUpdates = jsonOption(json \ "Accumulator Updates") + .map(_.extract[List[JValue]].map(accumulableInfoFromJson)) + .getOrElse(Seq[AccumulableInfo]()) + TaskKilled(killReason, accumUpdates) case `taskCommitDenied` => // Unfortunately, the `TaskCommitDenied` message was introduced in 1.3.0 but the JSON // de/serialization logic was not added until 1.5.1. To provide backward compatibility diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala index b25a731401f23..d4474a90b26f1 100644 --- a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala @@ -60,6 +60,15 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { } } + /** + * This can be overriden by subclasses if there is any extra cleanup to do when removing a + * listener. In particular AsyncEventQueues can clean up queues in the LiveListenerBus. + */ + def removeListenerOnError(listener: L): Unit = { + removeListener(listener) + } + + /** * Post the event to all registered listeners. The `postToAll` caller should guarantee calling * `postToAll` in the same thread for all events. @@ -80,7 +89,16 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { } try { doPostEvent(listener, event) + if (Thread.interrupted()) { + // We want to throw the InterruptedException right away so we can associate the interrupt + // with this listener, as opposed to waiting for a queue.take() etc. to detect it. + throw new InterruptedException() + } } catch { + case ie: InterruptedException => + logError(s"Interrupted while posting to ${Utils.getFormattedClassName(listener)}. " + + s"Removing that listener.", ie) + removeListenerOnError(listener) case NonFatal(e) if !isIgnorableException(e) => logError(s"Listener ${Utils.getFormattedClassName(listener)} threw an exception", e) } finally { diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 13adaa921dc23..f9191a59c1655 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -810,15 +810,15 @@ private[spark] object Utils extends Logging { conf.getenv("SPARK_EXECUTOR_DIRS").split(File.pathSeparator) } else if (conf.getenv("SPARK_LOCAL_DIRS") != null) { conf.getenv("SPARK_LOCAL_DIRS").split(",") - } else if (conf.getenv("MESOS_DIRECTORY") != null && !shuffleServiceEnabled) { + } else if (conf.getenv("MESOS_SANDBOX") != null && !shuffleServiceEnabled) { // Mesos already creates a directory per Mesos task. Spark should use that directory // instead so all temporary files are automatically cleaned up when the Mesos task ends. // Note that we don't want this if the shuffle service is enabled because we want to // continue to serve shuffle files after the executors that wrote them have already exited. - Array(conf.getenv("MESOS_DIRECTORY")) + Array(conf.getenv("MESOS_SANDBOX")) } else { - if (conf.getenv("MESOS_DIRECTORY") != null && shuffleServiceEnabled) { - logInfo("MESOS_DIRECTORY available but not using provided Mesos sandbox because " + + if (conf.getenv("MESOS_SANDBOX") != null && shuffleServiceEnabled) { + logInfo("MESOS_SANDBOX available but not using provided Mesos sandbox because " + "spark.shuffle.service.enabled is enabled.") } // In non-Yarn mode (or for the driver in yarn-client mode), we cannot trust the user diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala index 3ae8dfcc1cb66..700ce56466c35 100644 --- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala @@ -63,15 +63,19 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { */ def writeFully(channel: WritableByteChannel): Unit = { for (bytes <- getChunks()) { - val curChunkLimit = bytes.limit() + val originalLimit = bytes.limit() while (bytes.hasRemaining) { - try { - val ioSize = Math.min(bytes.remaining(), bufferWriteChunkSize) - bytes.limit(bytes.position() + ioSize) - channel.write(bytes) - } finally { - bytes.limit(curChunkLimit) - } + // If `bytes` is an on-heap ByteBuffer, the Java NIO API will copy it to a temporary direct + // ByteBuffer when writing it out. This temporary direct ByteBuffer is cached per thread. + // Its size has no limit and can keep growing if it sees a larger input ByteBuffer. This may + // cause significant native memory leak, if a large direct ByteBuffer is allocated and + // cached, as it's never released until thread exits. Here we write the `bytes` with + // fixed-size slices to limit the size of the cached direct ByteBuffer. + // Please refer to http://www.evanjones.ca/java-bytebuffer-leak.html for more details. + val ioSize = Math.min(bytes.remaining(), bufferWriteChunkSize) + bytes.limit(bytes.position() + ioSize) + channel.write(bytes) + bytes.limit(originalLimit) } } } diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index 3990ee1ec326d..5d0ffd92647bc 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -209,10 +209,8 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex System.gc() assert(ref.get.isEmpty) - // Getting a garbage collected accum should throw error - intercept[IllegalStateException] { - AccumulatorContext.get(accId) - } + // Getting a garbage collected accum should return None. + assert(AccumulatorContext.get(accId).isEmpty) // Getting a normal accumulator. Note: this has to be separate because referencing an // accumulator above in an `assert` would keep it from being garbage collected. diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 7451e07b25a1f..43286953e4383 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -180,6 +180,26 @@ class SparkSubmitSuite appArgs.toString should include ("thequeue") } + test("SPARK-24241: do not fail fast if executor num is 0 when dynamic allocation is enabled") { + val clArgs1 = Seq( + "--name", "myApp", + "--class", "Foo", + "--num-executors", "0", + "--conf", "spark.dynamicAllocation.enabled=true", + "thejar.jar") + new SparkSubmitArguments(clArgs1) + + val clArgs2 = Seq( + "--name", "myApp", + "--class", "Foo", + "--num-executors", "0", + "--conf", "spark.dynamicAllocation.enabled=false", + "thejar.jar") + + val e = intercept[SparkException](new SparkSubmitArguments(clArgs2)) + assert(e.getMessage.contains("Number of executors must be a positive number")) + } + test("specify deploy mode through configuration") { val clArgs = Seq( "--master", "yarn", diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index a871b1c717837..11b29121739a4 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -36,6 +36,7 @@ import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} import org.json4s.JsonAST._ import org.json4s.jackson.JsonMethods import org.json4s.jackson.JsonMethods._ +import org.mockito.Mockito._ import org.openqa.selenium.WebDriver import org.openqa.selenium.htmlunit.HtmlUnitDriver import org.scalatest.{BeforeAndAfter, Matchers} @@ -281,6 +282,29 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers getContentAndCode("foobar")._1 should be (HttpServletResponse.SC_NOT_FOUND) } + test("automatically retrieve uiRoot from request through Knox") { + assert(sys.props.get("spark.ui.proxyBase").isEmpty, + "spark.ui.proxyBase is defined but it should not for this UT") + assert(sys.env.get("APPLICATION_WEB_PROXY_BASE").isEmpty, + "APPLICATION_WEB_PROXY_BASE is defined but it should not for this UT") + val page = new HistoryPage(server) + val requestThroughKnox = mock[HttpServletRequest] + val knoxBaseUrl = "/gateway/default/sparkhistoryui" + when(requestThroughKnox.getHeader("X-Forwarded-Context")).thenReturn(knoxBaseUrl) + val responseThroughKnox = page.render(requestThroughKnox) + + val urlsThroughKnox = responseThroughKnox \\ "@href" map (_.toString) + val siteRelativeLinksThroughKnox = urlsThroughKnox filter (_.startsWith("/")) + all (siteRelativeLinksThroughKnox) should startWith (knoxBaseUrl) + + val directRequest = mock[HttpServletRequest] + val directResponse = page.render(directRequest) + + val directUrls = directResponse \\ "@href" map (_.toString) + val directSiteRelativeLinks = directUrls filter (_.startsWith("/")) + all (directSiteRelativeLinks) should not startWith (knoxBaseUrl) + } + test("static relative links are prefixed with uiRoot (spark.ui.proxyBase)") { val uiRoot = Option(System.getenv("APPLICATION_WEB_PROXY_BASE")).getOrElse("/testwebproxybase") val page = new HistoryPage(server) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 8b6ec37625eec..2987170bf5026 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -1852,7 +1852,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi assertDataStructuresEmpty() } - test("accumulators are updated on exception failures") { + test("accumulators are updated on exception failures and task killed") { val acc1 = AccumulatorSuite.createLongAccum("ingenieur") val acc2 = AccumulatorSuite.createLongAccum("boulanger") val acc3 = AccumulatorSuite.createLongAccum("agriculteur") @@ -1868,15 +1868,24 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi val accUpdate3 = new LongAccumulator accUpdate3.metadata = acc3.metadata accUpdate3.setValue(18) - val accumUpdates = Seq(accUpdate1, accUpdate2, accUpdate3) - val accumInfo = accumUpdates.map(AccumulatorSuite.makeInfo) + + val accumUpdates1 = Seq(accUpdate1, accUpdate2) + val accumInfo1 = accumUpdates1.map(AccumulatorSuite.makeInfo) val exceptionFailure = new ExceptionFailure( new SparkException("fondue?"), - accumInfo).copy(accums = accumUpdates) + accumInfo1).copy(accums = accumUpdates1) submit(new MyRDD(sc, 1, Nil), Array(0)) runEvent(makeCompletionEvent(taskSets.head.tasks.head, exceptionFailure, "result")) + assert(AccumulatorContext.get(acc1.id).get.value === 15L) assert(AccumulatorContext.get(acc2.id).get.value === 13L) + + val accumUpdates2 = Seq(accUpdate3) + val accumInfo2 = accumUpdates2.map(AccumulatorSuite.makeInfo) + + val taskKilled = new TaskKilled( "test", accumInfo2, accums = accumUpdates2) + runEvent(makeCompletionEvent(taskSets.head.tasks.head, taskKilled, "result")) + assert(AccumulatorContext.get(acc3.id).get.value === 18L) } @@ -2497,6 +2506,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi val accumUpdates = reason match { case Success => task.metrics.accumulators() case ef: ExceptionFailure => ef.accums + case tk: TaskKilled => tk.accums case _ => Seq.empty } CompletionEvent(task, reason, result, accumUpdates ++ extraAccumUpdates, taskInfo) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index fa47a52bbbc47..6ffd1e84f7adb 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -489,6 +489,48 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match assert(bus.findListenersByClass[BasicJobCounter]().isEmpty) } + Seq(true, false).foreach { throwInterruptedException => + val suffix = if (throwInterruptedException) "throw interrupt" else "set Thread interrupted" + test(s"interrupt within listener is handled correctly: $suffix") { + val conf = new SparkConf(false) + .set(LISTENER_BUS_EVENT_QUEUE_CAPACITY, 5) + val bus = new LiveListenerBus(conf) + val counter1 = new BasicJobCounter() + val counter2 = new BasicJobCounter() + val interruptingListener1 = new InterruptingListener(throwInterruptedException) + val interruptingListener2 = new InterruptingListener(throwInterruptedException) + bus.addToSharedQueue(counter1) + bus.addToSharedQueue(interruptingListener1) + bus.addToStatusQueue(counter2) + bus.addToEventLogQueue(interruptingListener2) + assert(bus.activeQueues() === Set(SHARED_QUEUE, APP_STATUS_QUEUE, EVENT_LOG_QUEUE)) + assert(bus.findListenersByClass[BasicJobCounter]().size === 2) + assert(bus.findListenersByClass[InterruptingListener]().size === 2) + + bus.start(mockSparkContext, mockMetricsSystem) + + // after we post one event, both interrupting listeners should get removed, and the + // event log queue should be removed + bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) + bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(bus.activeQueues() === Set(SHARED_QUEUE, APP_STATUS_QUEUE)) + assert(bus.findListenersByClass[BasicJobCounter]().size === 2) + assert(bus.findListenersByClass[InterruptingListener]().size === 0) + assert(counter1.count === 1) + assert(counter2.count === 1) + + // posting more events should be fine, they'll just get processed from the OK queue. + (0 until 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } + bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(counter1.count === 6) + assert(counter2.count === 6) + + // Make sure stopping works -- this requires putting a poison pill in all active queues, which + // would fail if our interrupted queue was still active, as its queue would be full. + bus.stop() + } + } + /** * Assert that the given list of numbers has an average that is greater than zero. */ @@ -547,6 +589,18 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { throw new Exception } } + /** + * A simple listener that interrupts on job end. + */ + private class InterruptingListener(val throwInterruptedException: Boolean) extends SparkListener { + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { + if (throwInterruptedException) { + throw new InterruptedException("got interrupted") + } else { + Thread.currentThread().interrupt() + } + } + } } // These classes can't be declared inside of the SparkListenerSuite class because we don't want diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala index a71521c91d2f2..cdc7f541b9552 100644 --- a/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.ui.storage +import javax.servlet.http.HttpServletRequest + import org.mockito.Mockito._ import org.apache.spark.SparkFunSuite @@ -29,6 +31,7 @@ class StoragePageSuite extends SparkFunSuite { val storageTab = mock(classOf[StorageTab]) when(storageTab.basePath).thenReturn("http://localhost:4040") val storagePage = new StoragePage(storageTab, null) + val request = mock(classOf[HttpServletRequest]) test("rddTable") { val rdd1 = new RDDStorageInfo(1, @@ -61,7 +64,7 @@ class StoragePageSuite extends SparkFunSuite { None, None) - val xmlNodes = storagePage.rddTable(Seq(rdd1, rdd2, rdd3)) + val xmlNodes = storagePage.rddTable(request, Seq(rdd1, rdd2, rdd3)) val headers = Seq( "ID", @@ -94,7 +97,7 @@ class StoragePageSuite extends SparkFunSuite { } test("empty rddTable") { - assert(storagePage.rddTable(Seq.empty).isEmpty) + assert(storagePage.rddTable(request, Seq.empty).isEmpty) } test("streamBlockStorageLevelDescriptionAndSize") { diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index c00b00b845401..5faa3d3260a56 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -371,11 +371,18 @@ if [[ "$1" == "publish-release" ]]; then find . -type f |grep -v \.jar |grep -v \.pom | xargs rm echo "Creating hash and signature files" - # this must have .asc and .sha1 - it really doesn't like anything else there + # this must have .asc, .md5 and .sha1 - it really doesn't like anything else there for file in $(find . -type f) do echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --output $file.asc \ --detach-sig --armour $file; + if [ $(command -v md5) ]; then + # Available on OS X; -q to keep only hash + md5 -q $file > $file.md5 + else + # Available on Linux; cut to keep only hash + md5sum $file | cut -f1 -d' ' > $file.md5 + fi sha1sum $file | cut -f1 -d' ' > $file.sha1 done diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index f552b81fde9f4..e710e26348117 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -190,7 +190,7 @@ stax-api-1.0.1.jar stream-2.7.0.jar stringtemplate-3.2.1.jar super-csv-2.2.0.jar -univocity-parsers-2.5.9.jar +univocity-parsers-2.6.3.jar validation-api-1.1.0.Final.jar xbean-asm5-shaded-4.4.jar xercesImpl-2.9.1.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 024b1fca717df..97ad17a9ff7b1 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -191,7 +191,7 @@ stax-api-1.0.1.jar stream-2.7.0.jar stringtemplate-3.2.1.jar super-csv-2.2.0.jar -univocity-parsers-2.5.9.jar +univocity-parsers-2.6.3.jar validation-api-1.1.0.Final.jar xbean-asm5-shaded-4.4.jar xercesImpl-2.9.1.jar diff --git a/dev/deps/spark-deps-hadoop-3.1 b/dev/deps/spark-deps-hadoop-3.1 index 938de7bc06663..e21bfef8c4291 100644 --- a/dev/deps/spark-deps-hadoop-3.1 +++ b/dev/deps/spark-deps-hadoop-3.1 @@ -211,7 +211,7 @@ stream-2.7.0.jar stringtemplate-3.2.1.jar super-csv-2.2.0.jar token-provider-1.0.1.jar -univocity-parsers-2.5.9.jar +univocity-parsers-2.6.3.jar validation-api-1.1.0.Final.jar woodstox-core-5.0.3.jar xbean-asm5-shaded-4.4.jar diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index 5ea205fbed4aa..7f46a1c8f6a7c 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -101,14 +101,15 @@ def continue_maybe(prompt): def clean_up(): - print("Restoring head pointer to %s" % original_head) - run_cmd("git checkout %s" % original_head) + if 'original_head' in globals(): + print("Restoring head pointer to %s" % original_head) + run_cmd("git checkout %s" % original_head) - branches = run_cmd("git branch").replace(" ", "").split("\n") + branches = run_cmd("git branch").replace(" ", "").split("\n") - for branch in filter(lambda x: x.startswith(BRANCH_PREFIX), branches): - print("Deleting local branch %s" % branch) - run_cmd("git branch -D %s" % branch) + for branch in filter(lambda x: x.startswith(BRANCH_PREFIX), branches): + print("Deleting local branch %s" % branch) + run_cmd("git branch -D %s" % branch) # merge the requested PR and return the merge hash diff --git a/docs/configuration.md b/docs/configuration.md index 8a1aacef85760..fd2670cba2125 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -208,7 +208,7 @@ of the most common options to set are: stored on disk. This should be on a fast, local disk in your system. It can also be a comma-separated list of multiple directories on different disks. - NOTE: In Spark 1.0 and later this will be overridden by SPARK_LOCAL_DIRS (Standalone, Mesos) or + NOTE: In Spark 1.0 and later this will be overridden by SPARK_LOCAL_DIRS (Standalone), MESOS_SANDBOX (Mesos) or LOCAL_DIRS (YARN) environment variables set by the cluster manager. diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index c9e68c3bfd056..4dbcbeafbbd9d 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -424,9 +424,12 @@ To use a custom metrics.properties for the application master and executors, upd Standard Kerberos support in Spark is covered in the [Security](security.html#kerberos) page. -In YARN mode, when accessing Hadoop file systems, aside from the service hosting the user's home -directory, Spark will also automatically obtain delegation tokens for the service hosting the -staging directory of the Spark application. +In YARN mode, when accessing Hadoop filesystems, Spark will automatically obtain delegation tokens +for: + +- the filesystem hosting the staging directory of the Spark application (which is the default + filesystem if `spark.yarn.stagingDir` is not set); +- if Hadoop federation is enabled, all the federated filesystems in the configuration. If an application needs to interact with other secure Hadoop filesystems, their URIs need to be explicitly provided to Spark at launch time. This is done by listing them in the diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 3f79ed6422205..fc26562ff33da 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1214,7 +1214,7 @@ The following options can be used to configure the version of Hive that is used 1.2.1 Version of the Hive metastore. Available - options are 0.12.0 through 2.3.2. + options are 0.12.0 through 2.3.3. @@ -1338,6 +1338,17 @@ the following case-insensitive options: + + queryTimeout + + The number of seconds the driver will wait for a Statement object to execute to the given + number of seconds. Zero means there is no limit. In the write path, this option depends on + how JDBC drivers implement the API setQueryTimeout, e.g., the h2 JDBC driver + checks the timeout of each query instead of an entire JDBC batch. + It defaults to 0. + + + fetchsize @@ -1814,6 +1825,7 @@ working with timestamps in `pandas_udf`s to get the best performance, see - In version 2.3 and earlier, `to_utc_timestamp` and `from_utc_timestamp` respect the timezone in the input timestamp string, which breaks the assumption that the input timestamp is in a specific timezone. Therefore, these 2 functions can return unexpected results. In version 2.4 and later, this problem has been fixed. `to_utc_timestamp` and `from_utc_timestamp` will return null if the input timestamp string contains timezone. As an example, `from_utc_timestamp('2000-10-10 00:00:00', 'GMT+1')` will return `2000-10-10 01:00:00` in both Spark 2.3 and 2.4. However, `from_utc_timestamp('2000-10-10 00:00:00+00:00', 'GMT+1')`, assuming a local timezone of GMT+8, will return `2000-10-10 09:00:00` in Spark 2.3 but `null` in 2.4. For people who don't care about this problem and want to retain the previous behaivor to keep their query unchanged, you can set `spark.sql.function.rejectTimezoneInString` to false. This option will be removed in Spark 3.0 and should only be used as a temporary workaround. - In version 2.3 and earlier, Spark converts Parquet Hive tables by default but ignores table properties like `TBLPROPERTIES (parquet.compression 'NONE')`. This happens for ORC Hive table properties like `TBLPROPERTIES (orc.compress 'NONE')` in case of `spark.sql.hive.convertMetastoreOrc=true`, too. Since Spark 2.4, Spark respects Parquet/ORC specific table properties while converting Parquet/ORC Hive tables. As an example, `CREATE TABLE t(id int) STORED AS PARQUET TBLPROPERTIES (parquet.compression 'NONE')` would generate Snappy parquet files during insertion in Spark 2.3, and in Spark 2.4, the result would be uncompressed parquet files. - Since Spark 2.0, Spark converts Parquet Hive tables by default for better performance. Since Spark 2.4, Spark converts ORC Hive tables by default, too. It means Spark uses its own ORC support by default instead of Hive SerDe. As an example, `CREATE TABLE t(id int) STORED AS ORC` would be handled with Hive SerDe in Spark 2.3, and in Spark 2.4, it would be converted into Spark's ORC data source table and ORC vectorization would be applied. To set `false` to `spark.sql.hive.convertMetastoreOrc` restores the previous behavior. + - In version 2.3 and earlier, CSV rows are considered as malformed if at least one column value in the row is malformed. CSV parser dropped such rows in the DROPMALFORMED mode or outputs an error in the FAILFAST mode. Since Spark 2.4, CSV row is considered as malformed only when it contains malformed column values requested from CSV datasource, other values can be ignored. As an example, CSV file contains the "id,name" header and one row "1234". In Spark 2.4, selection of the id column consists of a row with one column value 1234 but in Spark 2.3 and earlier it is empty in the DROPMALFORMED mode. To restore the previous behavior, set `spark.sql.csv.parser.columnPruning.enabled` to `false`. ## Upgrading From Spark SQL 2.2 to 2.3 @@ -2237,7 +2249,7 @@ referencing a singleton. Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. Currently, Hive SerDes and UDFs are based on Hive 1.2.1, and Spark SQL can be connected to different versions of Hive Metastore -(from 0.12.0 to 2.3.2. Also see [Interacting with Different Versions of Hive Metastore](#interacting-with-different-versions-of-hive-metastore)). +(from 0.12.0 to 2.3.3. Also see [Interacting with Different Versions of Hive Metastore](#interacting-with-different-versions-of-hive-metastore)). #### Deploying in Existing Hive Warehouses diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala index 88abf8a8dd027..badaa69cc303c 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala @@ -106,7 +106,7 @@ class KafkaContinuousReader( startOffsets.toSeq.map { case (topicPartition, start) => - KafkaContinuousDataReaderFactory( + KafkaContinuousInputPartition( topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss) .asInstanceOf[InputPartition[UnsafeRow]] }.asJava @@ -146,7 +146,7 @@ class KafkaContinuousReader( } /** - * A data reader factory for continuous Kafka processing. This will be serialized and transformed + * An input partition for continuous Kafka processing. This will be serialized and transformed * into a full reader on executors. * * @param topicPartition The (topic, partition) pair this task is responsible for. @@ -156,7 +156,7 @@ class KafkaContinuousReader( * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets * are skipped. */ -case class KafkaContinuousDataReaderFactory( +case class KafkaContinuousInputPartition( topicPartition: TopicPartition, startOffset: Long, kafkaParams: ju.Map[String, Object], diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala index 48508d057a540..941f0ab177e48 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataConsumer.scala @@ -395,7 +395,7 @@ private[kafka010] object KafkaDataConsumer extends Logging { // likely running on a beefy machine that can handle a large number of simultaneously // active consumers. - if (entry.getValue.inUse == false && this.size > capacity) { + if (!entry.getValue.inUse && this.size > capacity) { logWarning( s"KafkaConsumer cache hitting max capacity of $capacity, " + s"removing consumer for ${entry.getKey}") diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala index 8a377738ea782..64ba98762788c 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchReader.scala @@ -143,7 +143,7 @@ private[kafka010] class KafkaMicroBatchReader( // Generate factories based on the offset ranges val factories = offsetRanges.map { range => - new KafkaMicroBatchDataReaderFactory( + new KafkaMicroBatchInputPartition( range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer) } factories.map(_.asInstanceOf[InputPartition[UnsafeRow]]).asJava @@ -300,7 +300,7 @@ private[kafka010] class KafkaMicroBatchReader( } /** A [[InputPartition]] for reading Kafka data in a micro-batch streaming query. */ -private[kafka010] case class KafkaMicroBatchDataReaderFactory( +private[kafka010] case class KafkaMicroBatchInputPartition( offsetRange: KafkaOffsetRange, executorKafkaParams: ju.Map[String, Object], pollTimeoutMs: Long, diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index 871f9700cd1db..c6412eac97dba 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -679,7 +679,7 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase { Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 100L))) ) val factories = reader.planUnsafeInputPartitions().asScala - .map(_.asInstanceOf[KafkaMicroBatchDataReaderFactory]) + .map(_.asInstanceOf[KafkaMicroBatchInputPartition]) withClue(s"minPartitions = $minPartitions generated factories $factories\n\t") { assert(factories.size == numPartitionsGenerated) factories.foreach { f => assert(f.reuseKafkaConsumer == reusesConsumers) } diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala deleted file mode 100644 index aeb8c1dc342b3..0000000000000 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala +++ /dev/null @@ -1,226 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.kafka010 - -import java.{ util => ju } - -import org.apache.kafka.clients.consumer.{ ConsumerConfig, ConsumerRecord, KafkaConsumer } -import org.apache.kafka.common.{ KafkaException, TopicPartition } - -import org.apache.spark.internal.Logging - -/** - * Consumer of single topicpartition, intended for cached reuse. - * Underlying consumer is not threadsafe, so neither is this, - * but processing the same topicpartition and group id in multiple threads is usually bad anyway. - */ -private[kafka010] -class CachedKafkaConsumer[K, V] private( - val groupId: String, - val topic: String, - val partition: Int, - val kafkaParams: ju.Map[String, Object]) extends Logging { - - require(groupId == kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG), - "groupId used for cache key must match the groupId in kafkaParams") - - val topicPartition = new TopicPartition(topic, partition) - - protected val consumer = { - val c = new KafkaConsumer[K, V](kafkaParams) - val tps = new ju.ArrayList[TopicPartition]() - tps.add(topicPartition) - c.assign(tps) - c - } - - // TODO if the buffer was kept around as a random-access structure, - // could possibly optimize re-calculating of an RDD in the same batch - protected var buffer = ju.Collections.emptyListIterator[ConsumerRecord[K, V]]() - protected var nextOffset = -2L - - def close(): Unit = consumer.close() - - /** - * Get the record for the given offset, waiting up to timeout ms if IO is necessary. - * Sequential forward access will use buffers, but random access will be horribly inefficient. - */ - def get(offset: Long, timeout: Long): ConsumerRecord[K, V] = { - logDebug(s"Get $groupId $topic $partition nextOffset $nextOffset requested $offset") - if (offset != nextOffset) { - logInfo(s"Initial fetch for $groupId $topic $partition $offset") - seek(offset) - poll(timeout) - } - - if (!buffer.hasNext()) { poll(timeout) } - require(buffer.hasNext(), - s"Failed to get records for $groupId $topic $partition $offset after polling for $timeout") - var record = buffer.next() - - if (record.offset != offset) { - logInfo(s"Buffer miss for $groupId $topic $partition $offset") - seek(offset) - poll(timeout) - require(buffer.hasNext(), - s"Failed to get records for $groupId $topic $partition $offset after polling for $timeout") - record = buffer.next() - require(record.offset == offset, - s"Got wrong record for $groupId $topic $partition even after seeking to offset $offset " + - s"got offset ${record.offset} instead. If this is a compacted topic, consider enabling " + - "spark.streaming.kafka.allowNonConsecutiveOffsets" - ) - } - - nextOffset = offset + 1 - record - } - - /** - * Start a batch on a compacted topic - */ - def compactedStart(offset: Long, timeout: Long): Unit = { - logDebug(s"compacted start $groupId $topic $partition starting $offset") - // This seek may not be necessary, but it's hard to tell due to gaps in compacted topics - if (offset != nextOffset) { - logInfo(s"Initial fetch for compacted $groupId $topic $partition $offset") - seek(offset) - poll(timeout) - } - } - - /** - * Get the next record in the batch from a compacted topic. - * Assumes compactedStart has been called first, and ignores gaps. - */ - def compactedNext(timeout: Long): ConsumerRecord[K, V] = { - if (!buffer.hasNext()) { - poll(timeout) - } - require(buffer.hasNext(), - s"Failed to get records for compacted $groupId $topic $partition after polling for $timeout") - val record = buffer.next() - nextOffset = record.offset + 1 - record - } - - /** - * Rewind to previous record in the batch from a compacted topic. - * @throws NoSuchElementException if no previous element - */ - def compactedPrevious(): ConsumerRecord[K, V] = { - buffer.previous() - } - - private def seek(offset: Long): Unit = { - logDebug(s"Seeking to $topicPartition $offset") - consumer.seek(topicPartition, offset) - } - - private def poll(timeout: Long): Unit = { - val p = consumer.poll(timeout) - val r = p.records(topicPartition) - logDebug(s"Polled ${p.partitions()} ${r.size}") - buffer = r.listIterator - } - -} - -private[kafka010] -object CachedKafkaConsumer extends Logging { - - private case class CacheKey(groupId: String, topic: String, partition: Int) - - // Don't want to depend on guava, don't want a cleanup thread, use a simple LinkedHashMap - private var cache: ju.LinkedHashMap[CacheKey, CachedKafkaConsumer[_, _]] = null - - /** Must be called before get, once per JVM, to configure the cache. Further calls are ignored */ - def init( - initialCapacity: Int, - maxCapacity: Int, - loadFactor: Float): Unit = CachedKafkaConsumer.synchronized { - if (null == cache) { - logInfo(s"Initializing cache $initialCapacity $maxCapacity $loadFactor") - cache = new ju.LinkedHashMap[CacheKey, CachedKafkaConsumer[_, _]]( - initialCapacity, loadFactor, true) { - override def removeEldestEntry( - entry: ju.Map.Entry[CacheKey, CachedKafkaConsumer[_, _]]): Boolean = { - if (this.size > maxCapacity) { - try { - entry.getValue.consumer.close() - } catch { - case x: KafkaException => - logError("Error closing oldest Kafka consumer", x) - } - true - } else { - false - } - } - } - } - } - - /** - * Get a cached consumer for groupId, assigned to topic and partition. - * If matching consumer doesn't already exist, will be created using kafkaParams. - */ - def get[K, V]( - groupId: String, - topic: String, - partition: Int, - kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer[K, V] = - CachedKafkaConsumer.synchronized { - val k = CacheKey(groupId, topic, partition) - val v = cache.get(k) - if (null == v) { - logInfo(s"Cache miss for $k") - logDebug(cache.keySet.toString) - val c = new CachedKafkaConsumer[K, V](groupId, topic, partition, kafkaParams) - cache.put(k, c) - c - } else { - // any given topicpartition should have a consistent key and value type - v.asInstanceOf[CachedKafkaConsumer[K, V]] - } - } - - /** - * Get a fresh new instance, unassociated with the global cache. - * Caller is responsible for closing - */ - def getUncached[K, V]( - groupId: String, - topic: String, - partition: Int, - kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer[K, V] = - new CachedKafkaConsumer[K, V](groupId, topic, partition, kafkaParams) - - /** remove consumer for given groupId, topic, and partition, if it exists */ - def remove(groupId: String, topic: String, partition: Int): Unit = { - val k = CacheKey(groupId, topic, partition) - logInfo(s"Removing $k from cache") - val v = CachedKafkaConsumer.synchronized { - cache.remove(k) - } - if (null != v) { - v.close() - logInfo(s"Removed $k from cache") - } - } -} diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumer.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumer.scala new file mode 100644 index 0000000000000..68c5fe9ab066a --- /dev/null +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumer.scala @@ -0,0 +1,359 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka010 + +import java.{util => ju} + +import org.apache.kafka.clients.consumer.{ConsumerConfig, ConsumerRecord, KafkaConsumer} +import org.apache.kafka.common.{KafkaException, TopicPartition} + +import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging + +private[kafka010] sealed trait KafkaDataConsumer[K, V] { + /** + * Get the record for the given offset if available. + * + * @param offset the offset to fetch. + * @param pollTimeoutMs timeout in milliseconds to poll data from Kafka. + */ + def get(offset: Long, pollTimeoutMs: Long): ConsumerRecord[K, V] = { + internalConsumer.get(offset, pollTimeoutMs) + } + + /** + * Start a batch on a compacted topic + * + * @param offset the offset to fetch. + * @param pollTimeoutMs timeout in milliseconds to poll data from Kafka. + */ + def compactedStart(offset: Long, pollTimeoutMs: Long): Unit = { + internalConsumer.compactedStart(offset, pollTimeoutMs) + } + + /** + * Get the next record in the batch from a compacted topic. + * Assumes compactedStart has been called first, and ignores gaps. + * + * @param pollTimeoutMs timeout in milliseconds to poll data from Kafka. + */ + def compactedNext(pollTimeoutMs: Long): ConsumerRecord[K, V] = { + internalConsumer.compactedNext(pollTimeoutMs) + } + + /** + * Rewind to previous record in the batch from a compacted topic. + * + * @throws NoSuchElementException if no previous element + */ + def compactedPrevious(): ConsumerRecord[K, V] = { + internalConsumer.compactedPrevious() + } + + /** + * Release this consumer from being further used. Depending on its implementation, + * this consumer will be either finalized, or reset for reuse later. + */ + def release(): Unit + + /** Reference to the internal implementation that this wrapper delegates to */ + def internalConsumer: InternalKafkaConsumer[K, V] +} + + +/** + * A wrapper around Kafka's KafkaConsumer. + * This is not for direct use outside this file. + */ +private[kafka010] class InternalKafkaConsumer[K, V]( + val topicPartition: TopicPartition, + val kafkaParams: ju.Map[String, Object]) extends Logging { + + private[kafka010] val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG) + .asInstanceOf[String] + + private val consumer = createConsumer + + /** indicates whether this consumer is in use or not */ + var inUse = true + + /** indicate whether this consumer is going to be stopped in the next release */ + var markedForClose = false + + // TODO if the buffer was kept around as a random-access structure, + // could possibly optimize re-calculating of an RDD in the same batch + @volatile private var buffer = ju.Collections.emptyListIterator[ConsumerRecord[K, V]]() + @volatile private var nextOffset = InternalKafkaConsumer.UNKNOWN_OFFSET + + override def toString: String = { + "InternalKafkaConsumer(" + + s"hash=${Integer.toHexString(hashCode)}, " + + s"groupId=$groupId, " + + s"topicPartition=$topicPartition)" + } + + /** Create a KafkaConsumer to fetch records for `topicPartition` */ + private def createConsumer: KafkaConsumer[K, V] = { + val c = new KafkaConsumer[K, V](kafkaParams) + val topics = ju.Arrays.asList(topicPartition) + c.assign(topics) + c + } + + def close(): Unit = consumer.close() + + /** + * Get the record for the given offset, waiting up to timeout ms if IO is necessary. + * Sequential forward access will use buffers, but random access will be horribly inefficient. + */ + def get(offset: Long, timeout: Long): ConsumerRecord[K, V] = { + logDebug(s"Get $groupId $topicPartition nextOffset $nextOffset requested $offset") + if (offset != nextOffset) { + logInfo(s"Initial fetch for $groupId $topicPartition $offset") + seek(offset) + poll(timeout) + } + + if (!buffer.hasNext()) { + poll(timeout) + } + require(buffer.hasNext(), + s"Failed to get records for $groupId $topicPartition $offset after polling for $timeout") + var record = buffer.next() + + if (record.offset != offset) { + logInfo(s"Buffer miss for $groupId $topicPartition $offset") + seek(offset) + poll(timeout) + require(buffer.hasNext(), + s"Failed to get records for $groupId $topicPartition $offset after polling for $timeout") + record = buffer.next() + require(record.offset == offset, + s"Got wrong record for $groupId $topicPartition even after seeking to offset $offset " + + s"got offset ${record.offset} instead. If this is a compacted topic, consider enabling " + + "spark.streaming.kafka.allowNonConsecutiveOffsets" + ) + } + + nextOffset = offset + 1 + record + } + + /** + * Start a batch on a compacted topic + */ + def compactedStart(offset: Long, pollTimeoutMs: Long): Unit = { + logDebug(s"compacted start $groupId $topicPartition starting $offset") + // This seek may not be necessary, but it's hard to tell due to gaps in compacted topics + if (offset != nextOffset) { + logInfo(s"Initial fetch for compacted $groupId $topicPartition $offset") + seek(offset) + poll(pollTimeoutMs) + } + } + + /** + * Get the next record in the batch from a compacted topic. + * Assumes compactedStart has been called first, and ignores gaps. + */ + def compactedNext(pollTimeoutMs: Long): ConsumerRecord[K, V] = { + if (!buffer.hasNext()) { + poll(pollTimeoutMs) + } + require(buffer.hasNext(), + s"Failed to get records for compacted $groupId $topicPartition " + + s"after polling for $pollTimeoutMs") + val record = buffer.next() + nextOffset = record.offset + 1 + record + } + + /** + * Rewind to previous record in the batch from a compacted topic. + * @throws NoSuchElementException if no previous element + */ + def compactedPrevious(): ConsumerRecord[K, V] = { + buffer.previous() + } + + private def seek(offset: Long): Unit = { + logDebug(s"Seeking to $topicPartition $offset") + consumer.seek(topicPartition, offset) + } + + private def poll(timeout: Long): Unit = { + val p = consumer.poll(timeout) + val r = p.records(topicPartition) + logDebug(s"Polled ${p.partitions()} ${r.size}") + buffer = r.listIterator + } + +} + +private[kafka010] case class CacheKey(groupId: String, topicPartition: TopicPartition) + +private[kafka010] object KafkaDataConsumer extends Logging { + + private case class CachedKafkaDataConsumer[K, V](internalConsumer: InternalKafkaConsumer[K, V]) + extends KafkaDataConsumer[K, V] { + assert(internalConsumer.inUse) + override def release(): Unit = KafkaDataConsumer.release(internalConsumer) + } + + private case class NonCachedKafkaDataConsumer[K, V](internalConsumer: InternalKafkaConsumer[K, V]) + extends KafkaDataConsumer[K, V] { + override def release(): Unit = internalConsumer.close() + } + + // Don't want to depend on guava, don't want a cleanup thread, use a simple LinkedHashMap + private[kafka010] var cache: ju.Map[CacheKey, InternalKafkaConsumer[_, _]] = null + + /** + * Must be called before acquire, once per JVM, to configure the cache. + * Further calls are ignored. + */ + def init( + initialCapacity: Int, + maxCapacity: Int, + loadFactor: Float): Unit = synchronized { + if (null == cache) { + logInfo(s"Initializing cache $initialCapacity $maxCapacity $loadFactor") + cache = new ju.LinkedHashMap[CacheKey, InternalKafkaConsumer[_, _]]( + initialCapacity, loadFactor, true) { + override def removeEldestEntry( + entry: ju.Map.Entry[CacheKey, InternalKafkaConsumer[_, _]]): Boolean = { + + // Try to remove the least-used entry if its currently not in use. + // + // If you cannot remove it, then the cache will keep growing. In the worst case, + // the cache will grow to the max number of concurrent tasks that can run in the executor, + // (that is, number of tasks slots) after which it will never reduce. This is unlikely to + // be a serious problem because an executor with more than 64 (default) tasks slots is + // likely running on a beefy machine that can handle a large number of simultaneously + // active consumers. + + if (entry.getValue.inUse == false && this.size > maxCapacity) { + logWarning( + s"KafkaConsumer cache hitting max capacity of $maxCapacity, " + + s"removing consumer for ${entry.getKey}") + try { + entry.getValue.close() + } catch { + case x: KafkaException => + logError("Error closing oldest Kafka consumer", x) + } + true + } else { + false + } + } + } + } + } + + /** + * Get a cached consumer for groupId, assigned to topic and partition. + * If matching consumer doesn't already exist, will be created using kafkaParams. + * The returned consumer must be released explicitly using [[KafkaDataConsumer.release()]]. + * + * Note: This method guarantees that the consumer returned is not currently in use by anyone + * else. Within this guarantee, this method will make a best effort attempt to re-use consumers by + * caching them and tracking when they are in use. + */ + def acquire[K, V]( + topicPartition: TopicPartition, + kafkaParams: ju.Map[String, Object], + context: TaskContext, + useCache: Boolean): KafkaDataConsumer[K, V] = synchronized { + val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] + val key = new CacheKey(groupId, topicPartition) + val existingInternalConsumer = cache.get(key) + + lazy val newInternalConsumer = new InternalKafkaConsumer[K, V](topicPartition, kafkaParams) + + if (context != null && context.attemptNumber >= 1) { + // If this is reattempt at running the task, then invalidate cached consumers if any and + // start with a new one. If prior attempt failures were cache related then this way old + // problematic consumers can be removed. + logDebug(s"Reattempt detected, invalidating cached consumer $existingInternalConsumer") + if (existingInternalConsumer != null) { + // Consumer exists in cache. If its in use, mark it for closing later, or close it now. + if (existingInternalConsumer.inUse) { + existingInternalConsumer.markedForClose = true + } else { + existingInternalConsumer.close() + // Remove the consumer from cache only if it's closed. + // Marked for close consumers will be removed in release function. + cache.remove(key) + } + } + + logDebug("Reattempt detected, new non-cached consumer will be allocated " + + s"$newInternalConsumer") + NonCachedKafkaDataConsumer(newInternalConsumer) + } else if (!useCache) { + // If consumer reuse turned off, then do not use it, return a new consumer + logDebug("Cache usage turned off, new non-cached consumer will be allocated " + + s"$newInternalConsumer") + NonCachedKafkaDataConsumer(newInternalConsumer) + } else if (existingInternalConsumer == null) { + // If consumer is not already cached, then put a new in the cache and return it + logDebug("No cached consumer, new cached consumer will be allocated " + + s"$newInternalConsumer") + cache.put(key, newInternalConsumer) + CachedKafkaDataConsumer(newInternalConsumer) + } else if (existingInternalConsumer.inUse) { + // If consumer is already cached but is currently in use, then return a new consumer + logDebug("Used cached consumer found, new non-cached consumer will be allocated " + + s"$newInternalConsumer") + NonCachedKafkaDataConsumer(newInternalConsumer) + } else { + // If consumer is already cached and is currently not in use, then return that consumer + logDebug(s"Not used cached consumer found, re-using it $existingInternalConsumer") + existingInternalConsumer.inUse = true + // Any given TopicPartition should have a consistent key and value type + CachedKafkaDataConsumer(existingInternalConsumer.asInstanceOf[InternalKafkaConsumer[K, V]]) + } + } + + private def release(internalConsumer: InternalKafkaConsumer[_, _]): Unit = synchronized { + // Clear the consumer from the cache if this is indeed the consumer present in the cache + val key = new CacheKey(internalConsumer.groupId, internalConsumer.topicPartition) + val cachedInternalConsumer = cache.get(key) + if (internalConsumer.eq(cachedInternalConsumer)) { + // The released consumer is the same object as the cached one. + if (internalConsumer.markedForClose) { + internalConsumer.close() + cache.remove(key) + } else { + internalConsumer.inUse = false + } + } else { + // The released consumer is either not the same one as in the cache, or not in the cache + // at all. This may happen if the cache was invalidate while this consumer was being used. + // Just close this consumer. + internalConsumer.close() + logInfo(s"Released a supposedly cached consumer that was not found in the cache " + + s"$internalConsumer") + } + } +} + +private[kafka010] object InternalKafkaConsumer { + private val UNKNOWN_OFFSET = -2L +} diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala index 07239eda64d2e..81abc9860bfc3 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala @@ -19,8 +19,6 @@ package org.apache.spark.streaming.kafka010 import java.{ util => ju } -import scala.collection.mutable.ArrayBuffer - import org.apache.kafka.clients.consumer.{ ConsumerConfig, ConsumerRecord } import org.apache.kafka.common.TopicPartition @@ -239,26 +237,18 @@ private class KafkaRDDIterator[K, V]( cacheLoadFactor: Float ) extends Iterator[ConsumerRecord[K, V]] { - val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] - context.addTaskCompletionListener(_ => closeIfNeeded()) - val consumer = if (useConsumerCache) { - CachedKafkaConsumer.init(cacheInitialCapacity, cacheMaxCapacity, cacheLoadFactor) - if (context.attemptNumber >= 1) { - // just in case the prior attempt failures were cache related - CachedKafkaConsumer.remove(groupId, part.topic, part.partition) - } - CachedKafkaConsumer.get[K, V](groupId, part.topic, part.partition, kafkaParams) - } else { - CachedKafkaConsumer.getUncached[K, V](groupId, part.topic, part.partition, kafkaParams) + val consumer = { + KafkaDataConsumer.init(cacheInitialCapacity, cacheMaxCapacity, cacheLoadFactor) + KafkaDataConsumer.acquire[K, V](part.topicPartition(), kafkaParams, context, useConsumerCache) } var requestOffset = part.fromOffset def closeIfNeeded(): Unit = { - if (!useConsumerCache && consumer != null) { - consumer.close() + if (consumer != null) { + consumer.release() } } diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala new file mode 100644 index 0000000000000..d934c64962adb --- /dev/null +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaDataConsumerSuite.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka010 + +import java.util.concurrent.{Executors, TimeUnit} + +import scala.collection.JavaConverters._ +import scala.util.Random + +import org.apache.kafka.clients.consumer.ConsumerConfig._ +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.serialization.ByteArrayDeserializer +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark._ + +class KafkaDataConsumerSuite extends SparkFunSuite with BeforeAndAfterAll { + private var testUtils: KafkaTestUtils = _ + private val topic = "topic" + Random.nextInt() + private val topicPartition = new TopicPartition(topic, 0) + private val groupId = "groupId" + + override def beforeAll(): Unit = { + super.beforeAll() + testUtils = new KafkaTestUtils + testUtils.setup() + KafkaDataConsumer.init(16, 64, 0.75f) + } + + override def afterAll(): Unit = { + if (testUtils != null) { + testUtils.teardown() + testUtils = null + } + super.afterAll() + } + + private def getKafkaParams() = Map[String, Object]( + GROUP_ID_CONFIG -> groupId, + BOOTSTRAP_SERVERS_CONFIG -> testUtils.brokerAddress, + KEY_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName, + VALUE_DESERIALIZER_CLASS_CONFIG -> classOf[ByteArrayDeserializer].getName, + AUTO_OFFSET_RESET_CONFIG -> "earliest", + ENABLE_AUTO_COMMIT_CONFIG -> "false" + ).asJava + + test("KafkaDataConsumer reuse in case of same groupId and TopicPartition") { + KafkaDataConsumer.cache.clear() + + val kafkaParams = getKafkaParams() + + val consumer1 = KafkaDataConsumer.acquire[Array[Byte], Array[Byte]]( + topicPartition, kafkaParams, null, true) + consumer1.release() + + val consumer2 = KafkaDataConsumer.acquire[Array[Byte], Array[Byte]]( + topicPartition, kafkaParams, null, true) + consumer2.release() + + assert(KafkaDataConsumer.cache.size() == 1) + val key = new CacheKey(groupId, topicPartition) + val existingInternalConsumer = KafkaDataConsumer.cache.get(key) + assert(existingInternalConsumer.eq(consumer1.internalConsumer)) + assert(existingInternalConsumer.eq(consumer2.internalConsumer)) + } + + test("concurrent use of KafkaDataConsumer") { + val data = (1 to 1000).map(_.toString) + testUtils.createTopic(topic) + testUtils.sendMessages(topic, data.toArray) + + val kafkaParams = getKafkaParams() + + val numThreads = 100 + val numConsumerUsages = 500 + + @volatile var error: Throwable = null + + def consume(i: Int): Unit = { + val useCache = Random.nextBoolean + val taskContext = if (Random.nextBoolean) { + new TaskContextImpl(0, 0, 0, 0, attemptNumber = Random.nextInt(2), null, null, null) + } else { + null + } + val consumer = KafkaDataConsumer.acquire[Array[Byte], Array[Byte]]( + topicPartition, kafkaParams, taskContext, useCache) + try { + val rcvd = (0 until data.length).map { offset => + val bytes = consumer.get(offset, 10000).value() + new String(bytes) + } + assert(rcvd == data) + } catch { + case e: Throwable => + error = e + throw e + } finally { + consumer.release() + } + } + + val threadPool = Executors.newFixedThreadPool(numThreads) + try { + val futures = (1 to numConsumerUsages).map { i => + threadPool.submit(new Runnable { + override def run(): Unit = { consume(i) } + }) + } + futures.foreach(_.get(1, TimeUnit.MINUTES)) + assert(error == null) + } finally { + threadPool.shutdown() + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 3fb6d1e4e4f3e..337133a2e2326 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -146,12 +146,21 @@ class GBTClassifier @Since("1.4.0") ( @Since("1.4.0") def setLossType(value: String): this.type = set(lossType, value) + /** @group setParam */ + @Since("2.4.0") + def setValidationIndicatorCol(value: String): this.type = { + set(validationIndicatorCol, value) + } + override protected def train(dataset: Dataset[_]): GBTClassificationModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) + + val withValidation = isDefined(validationIndicatorCol) && $(validationIndicatorCol).nonEmpty + // We copy and modify this from Classifier.extractLabeledPoints since GBT only supports // 2 classes now. This lets us provide a more precise error message. - val oldDataset: RDD[LabeledPoint] = + val convert2LabeledPoint = (dataset: Dataset[_]) => { dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map { case Row(label: Double, features: Vector) => require(label == 0 || label == 1, s"GBTClassifier was given" + @@ -159,7 +168,18 @@ class GBTClassifier @Since("1.4.0") ( s" GBTClassifier currently only supports binary classification.") LabeledPoint(label, features) } - val numFeatures = oldDataset.first().features.size + } + + val (trainDataset, validationDataset) = if (withValidation) { + ( + convert2LabeledPoint(dataset.filter(not(col($(validationIndicatorCol))))), + convert2LabeledPoint(dataset.filter(col($(validationIndicatorCol)))) + ) + } else { + (convert2LabeledPoint(dataset), null) + } + + val numFeatures = trainDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification) val numClasses = 2 @@ -169,15 +189,21 @@ class GBTClassifier @Since("1.4.0") ( s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") } - val instr = Instrumentation.create(this, oldDataset) + val instr = Instrumentation.create(this, dataset) instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode, - seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy) + seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy, + validationIndicatorCol) instr.logNumFeatures(numFeatures) instr.logNumClasses(numClasses) - val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, - $(seed), $(featureSubsetStrategy)) + val (baseLearners, learnerWeights) = if (withValidation) { + GradientBoostedTrees.runWithValidation(trainDataset, validationDataset, boostingStrategy, + $(seed), $(featureSubsetStrategy)) + } else { + GradientBoostedTrees.run(trainDataset, boostingStrategy, $(seed), $(featureSubsetStrategy)) + } + val m = new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures) instr.logSuccess(m) m diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index 438e53ba6197c..1ad4e097246a3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -261,8 +261,9 @@ class BisectingKMeans @Since("2.0.0") ( transformSchema(dataset.schema, logging = true) val rdd = DatasetUtils.columnToOldVector(dataset, getFeaturesCol) - val instr = Instrumentation.create(this, rdd) - instr.logParams(featuresCol, predictionCol, k, maxIter, seed, minDivisibleClusterSize) + val instr = Instrumentation.create(this, dataset) + instr.logParams(featuresCol, predictionCol, k, maxIter, seed, + minDivisibleClusterSize, distanceMeasure) val bkm = new MLlibBisectingKMeans() .setK($(k)) @@ -275,6 +276,8 @@ class BisectingKMeans @Since("2.0.0") ( val summary = new BisectingKMeansSummary( model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) model.setSummary(Some(summary)) + // TODO: need to extend logNamedValue to support Array + instr.logNamedValue("clusterSizes", summary.clusterSizes.mkString("[", ",", "]")) instr.logSuccess(model) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 88d618c3a03a8..3091bb5a2e54c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -352,7 +352,7 @@ class GaussianMixture @Since("2.0.0") ( s"than ${GaussianMixture.MAX_NUM_FEATURES} features because the size of the covariance" + s" matrix is quadratic in the number of features.") - val instr = Instrumentation.create(this, instances) + val instr = Instrumentation.create(this, dataset) instr.logParams(featuresCol, predictionCol, probabilityCol, k, maxIter, seed, tol) instr.logNumFeatures(numFeatures) @@ -425,6 +425,9 @@ class GaussianMixture @Since("2.0.0") ( val summary = new GaussianMixtureSummary(model.transform(dataset), $(predictionCol), $(probabilityCol), $(featuresCol), $(k), logLikelihood) model.setSummary(Some(summary)) + instr.logNamedValue("logLikelihood", logLikelihood) + // TODO: need to extend logNamedValue to support Array + instr.logNamedValue("clusterSizes", summary.clusterSizes.mkString("[", ",", "]")) instr.logSuccess(model) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 97f246fbfd859..e72d7f9485e6a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -342,7 +342,7 @@ class KMeans @Since("1.5.0") ( instances.persist(StorageLevel.MEMORY_AND_DISK) } - val instr = Instrumentation.create(this, instances) + val instr = Instrumentation.create(this, dataset) instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, distanceMeasure, maxIter, seed, tol) val algo = new MLlibKMeans() @@ -359,6 +359,8 @@ class KMeans @Since("1.5.0") ( model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) model.setSummary(Some(summary)) + // TODO: need to extend logNamedValue to support Array + instr.logNamedValue("clusterSizes", summary.clusterSizes.mkString("[", ",", "]")) instr.logSuccess(model) if (handlePersistence) { instances.unpersist() diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index afe599cd167cb..fed42c959b5ef 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -569,10 +569,14 @@ abstract class LDAModel private[ml] ( class LocalLDAModel private[ml] ( uid: String, vocabSize: Int, - @Since("1.6.0") override private[clustering] val oldLocalModel: OldLocalLDAModel, + private[clustering] val oldLocalModel_ : OldLocalLDAModel, sparkSession: SparkSession) extends LDAModel(uid, vocabSize, sparkSession) { + override private[clustering] def oldLocalModel: OldLocalLDAModel = { + oldLocalModel_.setSeed(getSeed) + } + @Since("1.6.0") override def copy(extra: ParamMap): LocalLDAModel = { val copied = new LocalLDAModel(uid, vocabSize, oldLocalModel, sparkSession) diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index 0bf405d9abf9d..d7fbe28ae7a64 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -161,6 +161,8 @@ class FPGrowth @Since("2.2.0") ( private def genericFit[T: ClassTag](dataset: Dataset[_]): FPGrowthModel = { val handlePersistence = dataset.storageLevel == StorageLevel.NONE + val instr = Instrumentation.create(this, dataset) + instr.logParams(params: _*) val data = dataset.select($(itemsCol)) val items = data.where(col($(itemsCol)).isNotNull).rdd.map(r => r.getSeq[Any](0).toArray) val mllibFP = new MLlibFPGrowth().setMinSupport($(minSupport)) @@ -183,7 +185,9 @@ class FPGrowth @Since("2.2.0") ( items.unpersist() } - copyValues(new FPGrowthModel(uid, frequentItems)).setParent(this) + val model = copyValues(new FPGrowthModel(uid, frequentItems)).setParent(this) + instr.logSuccess(model) + model } @Since("2.2.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index b9c3170cc3c28..7e08675f834da 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -95,7 +95,10 @@ private[shared] object SharedParamsCodeGen { ParamDesc[String]("distanceMeasure", "The distance measure. Supported options: 'euclidean'" + " and 'cosine'", Some("org.apache.spark.mllib.clustering.DistanceMeasure.EUCLIDEAN"), isValid = "(value: String) => " + - "org.apache.spark.mllib.clustering.DistanceMeasure.validateDistanceMeasure(value)") + "org.apache.spark.mllib.clustering.DistanceMeasure.validateDistanceMeasure(value)"), + ParamDesc[String]("validationIndicatorCol", "name of the column that indicates whether " + + "each row is for training or for validation. False indicates training; true indicates " + + "validation.") ) val code = genSharedParams(params) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 282ea6ebcbf7f..5928a0749f738 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -523,4 +523,21 @@ trait HasDistanceMeasure extends Params { /** @group getParam */ final def getDistanceMeasure: String = $(distanceMeasure) } + +/** + * Trait for shared param validationIndicatorCol. This trait may be changed or + * removed between minor versions. + */ +@DeveloperApi +trait HasValidationIndicatorCol extends Params { + + /** + * Param for name of the column that indicates whether each row is for training or for validation. False indicates training; true indicates validation.. + * @group param + */ + final val validationIndicatorCol: Param[String] = new Param[String](this, "validationIndicatorCol", "name of the column that indicates whether each row is for training or for validation. False indicates training; true indicates validation.") + + /** @group getParam */ + final def getValidationIndicatorCol: String = $(validationIndicatorCol) +} // scalastyle:on diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index d7e054bf55ef6..eb8b3c001436a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -145,21 +145,42 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) override def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) + /** @group setParam */ + @Since("2.4.0") + def setValidationIndicatorCol(value: String): this.type = { + set(validationIndicatorCol, value) + } + override protected def train(dataset: Dataset[_]): GBTRegressionModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) - val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) - val numFeatures = oldDataset.first().features.size + + val withValidation = isDefined(validationIndicatorCol) && $(validationIndicatorCol).nonEmpty + + val (trainDataset, validationDataset) = if (withValidation) { + ( + extractLabeledPoints(dataset.filter(not(col($(validationIndicatorCol))))), + extractLabeledPoints(dataset.filter(col($(validationIndicatorCol)))) + ) + } else { + (extractLabeledPoints(dataset), null) + } + val numFeatures = trainDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression) - val instr = Instrumentation.create(this, oldDataset) + val instr = Instrumentation.create(this, dataset) instr.logParams(labelCol, featuresCol, predictionCol, impurity, lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode, seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy) instr.logNumFeatures(numFeatures) - val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, - $(seed), $(featureSubsetStrategy)) + val (baseLearners, learnerWeights) = if (withValidation) { + GradientBoostedTrees.runWithValidation(trainDataset, validationDataset, boostingStrategy, + $(seed), $(featureSubsetStrategy)) + } else { + GradientBoostedTrees.run(trainDataset, boostingStrategy, + $(seed), $(featureSubsetStrategy)) + } val m = new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures) instr.logSuccess(m) m diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index ec8868bb42cbb..00157fe63af41 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -21,6 +21,7 @@ import java.util.Locale import scala.util.Try +import org.apache.spark.annotation.Since import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -460,18 +461,34 @@ private[ml] trait RandomForestRegressorParams * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasStepSize { - - /* TODO: Add this doc when we add this param. SPARK-7132 - * Threshold for stopping early when runWithValidation is used. - * If the error rate on the validation input changes by less than the validationTol, - * then learning will stop early (before [[numIterations]]). - * This parameter is ignored when run is used. - * (default = 1e-5) +private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasStepSize + with HasValidationIndicatorCol { + + /** + * Threshold for stopping early when fit with validation is used. + * (This parameter is ignored when fit without validation is used.) + * The decision to stop early is decided based on this logic: + * If the current loss on the validation set is greater than 0.01, the diff + * of validation error is compared to relative tolerance which is + * validationTol * (current loss on the validation set). + * If the current loss on the validation set is less than or equal to 0.01, + * the diff of validation error is compared to absolute tolerance which is + * validationTol * 0.01. * @group param + * @see validationIndicatorCol */ - // final val validationTol: DoubleParam = new DoubleParam(this, "validationTol", "") - // validationTol -> 1e-5 + @Since("2.4.0") + final val validationTol: DoubleParam = new DoubleParam(this, "validationTol", + "Threshold for stopping early when fit with validation is used." + + "If the error rate on the validation input changes by less than the validationTol," + + "then learning will stop early (before `maxIter`)." + + "This parameter is ignored when fit without validation is used.", + ParamValidators.gtEq(0.0) + ) + + /** @group getParam */ + @Since("2.4.0") + final def getValidationTol: Double = $(validationTol) /** * @deprecated This method is deprecated and will be removed in 3.0.0. @@ -497,7 +514,7 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") def setStepSize(value: Double): this.type = set(stepSize, value) - setDefault(maxIter -> 20, stepSize -> 0.1) + setDefault(maxIter -> 20, stepSize -> 0.1, validationTol -> 0.01) setDefault(featureSubsetStrategy -> "all") @@ -507,7 +524,7 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS oldAlgo: OldAlgo.Algo): OldBoostingStrategy = { val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 2, oldAlgo, OldVariance) // NOTE: The old API does not support "seed" so we ignore it. - new OldBoostingStrategy(strategy, getOldLossType, getMaxIter, getStepSize) + new OldBoostingStrategy(strategy, getOldLossType, getMaxIter, getStepSize, getValidationTol) } /** Get old Gradient Boosting Loss type */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 5e916cc4a9fdd..f327f37bad204 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -144,7 +144,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) val metrics = splits.zipWithIndex.map { case ((training, validation), splitIndex) => val trainingDataset = sparkSession.createDataFrame(training, schema).cache() val validationDataset = sparkSession.createDataFrame(validation, schema).cache() - logDebug(s"Train split $splitIndex with multiple sets of parameters.") + instr.logDebug(s"Train split $splitIndex with multiple sets of parameters.") // Fit models in a Future for training in parallel val foldMetricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => @@ -155,7 +155,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) } // TODO: duplicate evaluator to take extra params from input val metric = eval.evaluate(model.transform(validationDataset, paramMap)) - logDebug(s"Got metric $metric for model trained with $paramMap.") + instr.logDebug(s"Got metric $metric for model trained with $paramMap.") metric } (executionContext) } @@ -169,12 +169,12 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) foldMetrics }.transpose.map(_.sum / $(numFolds)) // Calculate average metric over all splits - logInfo(s"Average cross-validation metrics: ${metrics.toSeq}") + instr.logInfo(s"Average cross-validation metrics: ${metrics.toSeq}") val (bestMetric, bestIndex) = if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1) else metrics.zipWithIndex.minBy(_._1) - logInfo(s"Best set of parameters:\n${epm(bestIndex)}") - logInfo(s"Best cross-validation metric: $bestMetric.") + instr.logInfo(s"Best set of parameters:\n${epm(bestIndex)}") + instr.logInfo(s"Best cross-validation metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] instr.logSuccess(bestModel) copyValues(new CrossValidatorModel(uid, bestModel, metrics) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 13369c4df7180..14d6a69c36747 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -143,7 +143,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St } else None // Fit models in a Future for training in parallel - logDebug(s"Train split with multiple sets of parameters.") + instr.logDebug(s"Train split with multiple sets of parameters.") val metricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => Future[Double] { val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]] @@ -153,7 +153,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St } // TODO: duplicate evaluator to take extra params from input val metric = eval.evaluate(model.transform(validationDataset, paramMap)) - logDebug(s"Got metric $metric for model trained with $paramMap.") + instr.logDebug(s"Got metric $metric for model trained with $paramMap.") metric } (executionContext) } @@ -165,12 +165,12 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St trainingDataset.unpersist() validationDataset.unpersist() - logInfo(s"Train validation split metrics: ${metrics.toSeq}") + instr.logInfo(s"Train validation split metrics: ${metrics.toSeq}") val (bestMetric, bestIndex) = if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1) else metrics.zipWithIndex.minBy(_._1) - logInfo(s"Best set of parameters:\n${epm(bestIndex)}") - logInfo(s"Best train validation split metric: $bestMetric.") + instr.logInfo(s"Best set of parameters:\n${epm(bestIndex)}") + instr.logInfo(s"Best train validation split metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] instr.logSuccess(bestModel) copyValues(new TrainValidationSplitModel(uid, bestModel, metrics) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala index 3247c394dfa64..467130b37c16e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala @@ -58,6 +58,13 @@ private[spark] class Instrumentation[E <: Estimator[_]] private ( s" storageLevel=${dataset.getStorageLevel}") } + /** + * Logs a debug message with a prefix that uniquely identifies the training session. + */ + override def logDebug(msg: => String): Unit = { + super.logDebug(prefix + msg) + } + /** * Logs a warning message with a prefix that uniquely identifies the training session. */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index b8a6e94248421..f915062d77389 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -32,7 +32,7 @@ import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors} import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SparkSession} -import org.apache.spark.util.BoundedPriorityQueue +import org.apache.spark.util.{BoundedPriorityQueue, Utils} /** * Latent Dirichlet Allocation (LDA) model. @@ -194,6 +194,8 @@ class LocalLDAModel private[spark] ( override protected[spark] val gammaShape: Double = 100) extends LDAModel with Serializable { + private var seed: Long = Utils.random.nextLong() + @Since("1.3.0") override def k: Int = topics.numCols @@ -216,6 +218,21 @@ class LocalLDAModel private[spark] ( override protected def formatVersion = "1.0" + /** + * Random seed for cluster initialization. + */ + @Since("2.4.0") + def getSeed: Long = seed + + /** + * Set the random seed for cluster initialization. + */ + @Since("2.4.0") + def setSeed(seed: Long): this.type = { + this.seed = seed + this + } + @Since("1.5.0") override def save(sc: SparkContext, path: String): Unit = { LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix, docConcentration, topicConcentration, @@ -298,6 +315,7 @@ class LocalLDAModel private[spark] ( // by topic (columns of lambda) val Elogbeta = LDAUtils.dirichletExpectation(lambda.t).t val ElogbetaBc = documents.sparkContext.broadcast(Elogbeta) + val gammaSeed = this.seed // Sum bound components for each document: // component for prob(tokens) + component for prob(document-topic distribution) @@ -306,7 +324,7 @@ class LocalLDAModel private[spark] ( val localElogbeta = ElogbetaBc.value var docBound = 0.0D val (gammad: BDV[Double], _, _) = OnlineLDAOptimizer.variationalTopicInference( - termCounts, exp(localElogbeta), brzAlpha, gammaShape, k) + termCounts, exp(localElogbeta), brzAlpha, gammaShape, k, gammaSeed + id) val Elogthetad: BDV[Double] = LDAUtils.dirichletExpectation(gammad) // E[log p(doc | theta, beta)] @@ -352,6 +370,7 @@ class LocalLDAModel private[spark] ( val docConcentrationBrz = this.docConcentration.asBreeze val gammaShape = this.gammaShape val k = this.k + val gammaSeed = this.seed documents.map { case (id: Long, termCounts: Vector) => if (termCounts.numNonzeros == 0) { @@ -362,7 +381,8 @@ class LocalLDAModel private[spark] ( expElogbetaBc.value, docConcentrationBrz, gammaShape, - k) + k, + gammaSeed + id) (id, Vectors.dense(normalize(gamma, 1.0).toArray)) } } @@ -376,6 +396,7 @@ class LocalLDAModel private[spark] ( val docConcentrationBrz = this.docConcentration.asBreeze val gammaShape = this.gammaShape val k = this.k + val gammaSeed = this.seed (termCounts: Vector) => if (termCounts.numNonzeros == 0) { @@ -386,7 +407,8 @@ class LocalLDAModel private[spark] ( expElogbeta, docConcentrationBrz, gammaShape, - k) + k, + gammaSeed) Vectors.dense(normalize(gamma, 1.0).toArray) } } @@ -403,6 +425,7 @@ class LocalLDAModel private[spark] ( */ @Since("2.0.0") def topicDistribution(document: Vector): Vector = { + val gammaSeed = this.seed val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.asBreeze.toDenseMatrix.t).t) if (document.numNonzeros == 0) { Vectors.zeros(this.k) @@ -412,7 +435,8 @@ class LocalLDAModel private[spark] ( expElogbeta, this.docConcentration.asBreeze, gammaShape, - this.k) + this.k, + gammaSeed) Vectors.dense(normalize(gamma, 1.0).toArray) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index 693a2a31f026b..f8e5f3ed76457 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -30,6 +30,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector, Vectors} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils /** * :: DeveloperApi :: @@ -464,6 +465,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer with Logging { val alpha = this.alpha.asBreeze val gammaShape = this.gammaShape val optimizeDocConcentration = this.optimizeDocConcentration + val seed = randomGenerator.nextLong() // If and only if optimizeDocConcentration is set true, // we calculate logphat in the same pass as other statistics. // No calculation of loghat happens otherwise. @@ -473,20 +475,21 @@ final class OnlineLDAOptimizer extends LDAOptimizer with Logging { None } - val stats: RDD[(BDM[Double], Option[BDV[Double]], Long)] = batch.mapPartitions { docs => - val nonEmptyDocs = docs.filter(_._2.numNonzeros > 0) - - val stat = BDM.zeros[Double](k, vocabSize) - val logphatPartOption = logphatPartOptionBase() - var nonEmptyDocCount: Long = 0L - nonEmptyDocs.foreach { case (_, termCounts: Vector) => - nonEmptyDocCount += 1 - val (gammad, sstats, ids) = OnlineLDAOptimizer.variationalTopicInference( - termCounts, expElogbetaBc.value, alpha, gammaShape, k) - stat(::, ids) := stat(::, ids) + sstats - logphatPartOption.foreach(_ += LDAUtils.dirichletExpectation(gammad)) - } - Iterator((stat, logphatPartOption, nonEmptyDocCount)) + val stats: RDD[(BDM[Double], Option[BDV[Double]], Long)] = batch.mapPartitionsWithIndex { + (index, docs) => + val nonEmptyDocs = docs.filter(_._2.numNonzeros > 0) + + val stat = BDM.zeros[Double](k, vocabSize) + val logphatPartOption = logphatPartOptionBase() + var nonEmptyDocCount: Long = 0L + nonEmptyDocs.foreach { case (_, termCounts: Vector) => + nonEmptyDocCount += 1 + val (gammad, sstats, ids) = OnlineLDAOptimizer.variationalTopicInference( + termCounts, expElogbetaBc.value, alpha, gammaShape, k, seed + index) + stat(::, ids) := stat(::, ids) + sstats + logphatPartOption.foreach(_ += LDAUtils.dirichletExpectation(gammad)) + } + Iterator((stat, logphatPartOption, nonEmptyDocCount)) } val elementWiseSum = ( @@ -578,7 +581,8 @@ final class OnlineLDAOptimizer extends LDAOptimizer with Logging { } override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = { - new LocalLDAModel(Matrices.fromBreeze(lambda).transpose, alpha, eta, gammaShape) + new LocalLDAModel(Matrices.fromBreeze(lambda).transpose, alpha, eta) + .setSeed(randomGenerator.nextLong()) } } @@ -605,18 +609,20 @@ private[clustering] object OnlineLDAOptimizer { expElogbeta: BDM[Double], alpha: breeze.linalg.Vector[Double], gammaShape: Double, - k: Int): (BDV[Double], BDM[Double], List[Int]) = { + k: Int, + seed: Long): (BDV[Double], BDM[Double], List[Int]) = { val (ids: List[Int], cts: Array[Double]) = termCounts match { case v: DenseVector => ((0 until v.size).toList, v.values) case v: SparseVector => (v.indices.toList, v.values) } // Initialize the variational distribution q(theta|gamma) for the mini-batch + val randBasis = new RandBasis(new org.apache.commons.math3.random.MersenneTwister(seed)) val gammad: BDV[Double] = - new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k) // K + new Gamma(gammaShape, 1.0 / gammaShape)(randBasis).samplesVector(k) // K val expElogthetad: BDV[Double] = exp(LDAUtils.dirichletExpectation(gammad)) // K val expElogbetad = expElogbeta(ids, ::).toDenseMatrix // ids * K - val phiNorm: BDV[Double] = expElogbetad * expElogthetad +:+ 1e-100 // ids + val phiNorm: BDV[Double] = expElogbetad * expElogthetad +:+ 1e-100 // ids var meanGammaChange = 1D val ctsVector = new BDV[Double](cts) // ids diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index 3ca75e8cdb97a..7a5e520d5818e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -43,7 +43,7 @@ import org.apache.spark.util.random.XORShiftRandom * $$ * \begin{align} * c_t+1 &= [(c_t * n_t * a) + (x_t * m_t)] / [n_t + m_t] \\ - * n_t+t &= n_t * a + m_t + * n_t+1 &= n_t * a + m_t * \end{align} * $$ * diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index e20de196d65ca..e6d2a8e2b900e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -34,6 +34,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.loss.LogLoss import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.lit import org.apache.spark.util.Utils /** @@ -392,6 +393,51 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { assert(evalArr(2) ~== lossErr3 relTol 1E-3) } + test("runWithValidation stops early and performs better on a validation dataset") { + val validationIndicatorCol = "validationIndicator" + val trainDF = trainData.toDF().withColumn(validationIndicatorCol, lit(false)) + val validationDF = validationData.toDF().withColumn(validationIndicatorCol, lit(true)) + + val numIter = 20 + for (lossType <- GBTClassifier.supportedLossTypes) { + val gbt = new GBTClassifier() + .setSeed(123) + .setMaxDepth(2) + .setLossType(lossType) + .setMaxIter(numIter) + val modelWithoutValidation = gbt.fit(trainDF) + + gbt.setValidationIndicatorCol(validationIndicatorCol) + val modelWithValidation = gbt.fit(trainDF.union(validationDF)) + + assert(modelWithoutValidation.numTrees === numIter) + // early stop + assert(modelWithValidation.numTrees < numIter) + + val (errorWithoutValidation, errorWithValidation) = { + val remappedRdd = validationData.map(x => new LabeledPoint(2 * x.label - 1, x.features)) + (GradientBoostedTrees.computeError(remappedRdd, modelWithoutValidation.trees, + modelWithoutValidation.treeWeights, modelWithoutValidation.getOldLossType), + GradientBoostedTrees.computeError(remappedRdd, modelWithValidation.trees, + modelWithValidation.treeWeights, modelWithValidation.getOldLossType)) + } + assert(errorWithValidation < errorWithoutValidation) + + val evaluationArray = GradientBoostedTrees + .evaluateEachIteration(validationData, modelWithoutValidation.trees, + modelWithoutValidation.treeWeights, modelWithoutValidation.getOldLossType, + OldAlgo.Classification) + assert(evaluationArray.length === numIter) + assert(evaluationArray(modelWithValidation.numTrees) > + evaluationArray(modelWithValidation.numTrees - 1)) + var i = 1 + while (i < modelWithValidation.numTrees) { + assert(evaluationArray(i) <= evaluationArray(i - 1)) + i += 1 + } + } + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index f3ff2afcad2cd..81842afbddbbb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -19,17 +19,18 @@ package org.apache.spark.ml.clustering import scala.language.existentials -import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.SparkException import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.clustering.DistanceMeasure -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.Dataset -class BisectingKMeansSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + +class BisectingKMeansSuite extends MLTest with DefaultReadWriteTest { + + import testImplicits._ final val k = 5 @transient var dataset: Dataset[_] = _ @@ -68,10 +69,13 @@ class BisectingKMeansSuite // Verify fit does not fail on very sparse data val model = bkm.fit(sparseDataset) - val result = model.transform(sparseDataset) - val numClusters = result.select("prediction").distinct().collect().length - // Verify we hit the edge case - assert(numClusters < k && numClusters > 1) + + testTransformerByGlobalCheckFunc[Tuple1[Vector]](sparseDataset.toDF(), model, "prediction") { + rows => + val numClusters = rows.distinct.length + // Verify we hit the edge case + assert(numClusters < k && numClusters > 1) + } } test("setter/getter") { @@ -104,19 +108,16 @@ class BisectingKMeansSuite val bkm = new BisectingKMeans().setK(k).setPredictionCol(predictionColName).setSeed(1) val model = bkm.fit(dataset) assert(model.clusterCenters.length === k) - - val transformed = model.transform(dataset) - val expectedColumns = Array("features", predictionColName) - expectedColumns.foreach { column => - assert(transformed.columns.contains(column)) - } - val clusters = - transformed.select(predictionColName).rdd.map(_.getInt(0)).distinct().collect().toSet - assert(clusters.size === k) - assert(clusters === Set(0, 1, 2, 3, 4)) assert(model.computeCost(dataset) < 0.1) assert(model.hasParent) + testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataset.toDF(), model, + "features", predictionColName) { rows => + val clusters = rows.map(_.getAs[Int](predictionColName)).toSet + assert(clusters.size === k) + assert(clusters === Set(0, 1, 2, 3, 4)) + } + // Check validity of model summary val numRows = dataset.count() assert(model.hasSummary) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala index d0d461a42711a..0b91f502f615b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala @@ -23,16 +23,15 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, Vector, Vectors} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.stat.distribution.MultivariateGaussian -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.{Dataset, Row} -class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext - with DefaultReadWriteTest { - import testImplicits._ +class GaussianMixtureSuite extends MLTest with DefaultReadWriteTest { + import GaussianMixtureSuite._ + import testImplicits._ final val k = 5 private val seed = 538009335 @@ -119,15 +118,10 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext assert(model.weights.length === k) assert(model.gaussians.length === k) - val transformed = model.transform(dataset) - val expectedColumns = Array("features", predictionColName, probabilityColName) - expectedColumns.foreach { column => - assert(transformed.columns.contains(column)) - } - // Check prediction matches the highest probability, and probabilities sum to one. - transformed.select(predictionColName, probabilityColName).collect().foreach { - case Row(pred: Int, prob: Vector) => + testTransformer[Tuple1[Vector]](dataset.toDF(), model, + "features", predictionColName, probabilityColName) { + case Row(_, pred: Int, prob: Vector) => val probArray = prob.toArray val predFromProb = probArray.zipWithIndex.maxBy(_._1)._2 assert(pred === predFromProb) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 680a7c2034083..2569e7a432ca4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -22,20 +22,21 @@ import scala.util.Random import org.dmg.pmml.{ClusteringModel, PMML} -import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.SparkException import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils, PMMLReadWriteTest} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} +import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans, + KMeansModel => MLlibKMeansModel} import org.apache.spark.mllib.linalg.{Vectors => MLlibVectors} -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} private[clustering] case class TestRow(features: Vector) -class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest - with PMMLReadWriteTest { +class KMeansSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTest { + + import testImplicits._ final val k = 5 @transient var dataset: Dataset[_] = _ @@ -109,15 +110,13 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR val model = kmeans.fit(dataset) assert(model.clusterCenters.length === k) - val transformed = model.transform(dataset) - val expectedColumns = Array("features", predictionColName) - expectedColumns.foreach { column => - assert(transformed.columns.contains(column)) + testTransformerByGlobalCheckFunc[Tuple1[Vector]](dataset.toDF(), model, + "features", predictionColName) { rows => + val clusters = rows.map(_.getAs[Int](predictionColName)).toSet + assert(clusters.size === k) + assert(clusters === Set(0, 1, 2, 3, 4)) } - val clusters = - transformed.select(predictionColName).rdd.map(_.getInt(0)).distinct().collect().toSet - assert(clusters.size === k) - assert(clusters === Set(0, 1, 2, 3, 4)) + assert(model.computeCost(dataset) < 0.1) assert(model.hasParent) @@ -149,9 +148,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR model.setFeaturesCol(featuresColName).setPredictionCol(predictionColName) val transformed = model.transform(dataset.withColumnRenamed("features", featuresColName)) - Seq(featuresColName, predictionColName).foreach { column => - assert(transformed.columns.contains(column)) - } + assert(transformed.schema.fieldNames.toSet === Set(featuresColName, predictionColName)) assert(model.getFeaturesCol == featuresColName) assert(model.getPredictionCol == predictionColName) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index 8d728f063dd8c..096b5416899e1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -21,11 +21,9 @@ import scala.language.existentials import org.apache.hadoop.fs.Path -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql._ object LDASuite { @@ -61,7 +59,7 @@ object LDASuite { } -class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class LDASuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -186,16 +184,11 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead assert(model.topicsMatrix.numCols === k) assert(!model.isDistributed) - // transform() - val transformed = model.transform(dataset) - val expectedColumns = Array("features", lda.getTopicDistributionCol) - expectedColumns.foreach { column => - assert(transformed.columns.contains(column)) - } - transformed.select(lda.getTopicDistributionCol).collect().foreach { r => - val topicDistribution = r.getAs[Vector](0) - assert(topicDistribution.size === k) - assert(topicDistribution.toArray.forall(w => w >= 0.0 && w <= 1.0)) + testTransformer[Tuple1[Vector]](dataset.toDF(), model, + "features", lda.getTopicDistributionCol) { + case Row(_, topicDistribution: Vector) => + assert(topicDistribution.size === k) + assert(topicDistribution.toArray.forall(w => w >= 0.0 && w <= 1.0)) } // logLikelihood, logPerplexity @@ -253,6 +246,12 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead val lda = new LDA() testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings, LDASuite.allParamSettings, checkModelData) + + // Make sure the result is deterministic after saving and loading the model + val model = lda.fit(dataset) + val model2 = testDefaultReadWrite(model) + assert(model.logLikelihood(dataset) ~== model2.logLikelihood(dataset) absTol 1e-6) + assert(model.logPerplexity(dataset) ~== model2.logPerplexity(dataset) absTol 1e-6) } test("read/write DistributedLDAModel") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 773f6d2c542fe..b145c7a3dc952 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.lit import org.apache.spark.util.Utils /** @@ -231,7 +232,52 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest { } } - ///////////////////////////////////////////////////////////////////////////// + test("runWithValidation stops early and performs better on a validation dataset") { + val validationIndicatorCol = "validationIndicator" + val trainDF = trainData.toDF().withColumn(validationIndicatorCol, lit(false)) + val validationDF = validationData.toDF().withColumn(validationIndicatorCol, lit(true)) + + val numIter = 20 + for (lossType <- GBTRegressor.supportedLossTypes) { + val gbt = new GBTRegressor() + .setSeed(123) + .setMaxDepth(2) + .setLossType(lossType) + .setMaxIter(numIter) + val modelWithoutValidation = gbt.fit(trainDF) + + gbt.setValidationIndicatorCol(validationIndicatorCol) + val modelWithValidation = gbt.fit(trainDF.union(validationDF)) + + assert(modelWithoutValidation.numTrees === numIter) + // early stop + assert(modelWithValidation.numTrees < numIter) + + val errorWithoutValidation = GradientBoostedTrees.computeError(validationData, + modelWithoutValidation.trees, modelWithoutValidation.treeWeights, + modelWithoutValidation.getOldLossType) + val errorWithValidation = GradientBoostedTrees.computeError(validationData, + modelWithValidation.trees, modelWithValidation.treeWeights, + modelWithValidation.getOldLossType) + + assert(errorWithValidation < errorWithoutValidation) + + val evaluationArray = GradientBoostedTrees + .evaluateEachIteration(validationData, modelWithoutValidation.trees, + modelWithoutValidation.treeWeights, modelWithoutValidation.getOldLossType, + OldAlgo.Regression) + assert(evaluationArray.length === numIter) + assert(evaluationArray(modelWithValidation.numTrees) > + evaluationArray(modelWithValidation.numTrees - 1)) + var i = 1 + while (i < modelWithValidation.numTrees) { + assert(evaluationArray(i) <= evaluationArray(i - 1)) + i += 1 + } + } + } + + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 7d0e88ee20c3f..4f6d5ff898681 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,11 @@ object MimaExcludes { // Exclude rules for 2.4.x lazy val v24excludes = v23excludes ++ Seq( + // [SPARK-20087][CORE] Attach accumulators / metrics to 'TaskKilled' end reason + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.apply"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.copy"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.this"), + // [SPARK-22941][core] Do not exit JVM when submit fails with in-process launcher. ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.printWarning"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.parseSparkConfProperty"), @@ -73,7 +78,18 @@ object MimaExcludes { ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.InternalNode"), ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.tree.Node"), ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.this") + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.this"), + + // [SPARK-7132][ML] Add fit with validation set to spark.ml GBT + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.getValidationIndicatorCol"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.org$apache$spark$ml$param$shared$HasValidationIndicatorCol$_setter_$validationIndicatorCol_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.validationIndicatorCol"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.getValidationIndicatorCol"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.org$apache$spark$ml$param$shared$HasValidationIndicatorCol$_setter_$validationIndicatorCol_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.validationIndicatorCol"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.getValidationIndicatorCol"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.org$apache$spark$ml$param$shared$HasValidationIndicatorCol$_setter_$validationIndicatorCol_="), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.validationIndicatorCol") ) // Exclude rules for 2.3.x diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index ea845b98b3db2..88519d7311fcc 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -272,7 +272,7 @@ def save_memoryview(self, obj): if not PY3: def save_buffer(self, obj): self.save(str(obj)) - dispatch[buffer] = save_buffer + dispatch[buffer] = save_buffer # noqa: F821 'buffer' was removed in Python 3 def save_unsupported(self, obj): raise pickle.PicklingError("Cannot pickle objects of type %s" % type(obj)) @@ -801,10 +801,10 @@ def save_ellipsis(self, obj): def save_not_implemented(self, obj): self.save_reduce(_gen_not_implemented, ()) - if PY3: - dispatch[io.TextIOWrapper] = save_file - else: + try: # Python 2 dispatch[file] = save_file + except NameError: # Python 3 + dispatch[io.TextIOWrapper] = save_file dispatch[type(Ellipsis)] = save_ellipsis dispatch[type(NotImplemented)] = save_not_implemented @@ -819,6 +819,11 @@ def save_logger(self, obj): dispatch[logging.Logger] = save_logger + def save_root_logger(self, obj): + self.save_reduce(logging.getLogger, (), obj=obj) + + dispatch[logging.RootLogger] = save_root_logger + """Special functions for Add-on libraries""" def inject_addons(self): """Plug in system. Register additional pickling functions if modules already loaded""" diff --git a/python/pyspark/context.py b/python/pyspark/context.py index dbb463f6005a1..ede3b6af0a8cf 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -211,9 +211,21 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, for path in self._conf.get("spark.submit.pyFiles", "").split(","): if path != "": (dirname, filename) = os.path.split(path) - if filename[-4:].lower() in self.PACKAGE_EXTENSIONS: - self._python_includes.append(filename) - sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename)) + try: + filepath = os.path.join(SparkFiles.getRootDirectory(), filename) + if not os.path.exists(filepath): + # In case of YARN with shell mode, 'spark.submit.pyFiles' files are + # not added via SparkContext.addFile. Here we check if the file exists, + # try to copy and then add it to the path. See SPARK-21945. + shutil.copyfile(path, filepath) + if filename[-4:].lower() in self.PACKAGE_EXTENSIONS: + self._python_includes.append(filename) + sys.path.insert(1, filepath) + except Exception: + warnings.warn( + "Failed to add file [%s] speficied in 'spark.submit.pyFiles' to " + "Python path:\n %s" % (path, "\n ".join(sys.path)), + RuntimeWarning) # Create a temporary directory inside spark.local.dir: local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir(self._jsc.sc().conf()) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index ec17653a1adf9..424ecfd89b060 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -1222,6 +1222,10 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol True >>> model.trees [DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...] + >>> validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0),)], + ... ["indexed", "features"]) + >>> model.evaluateEachIteration(validation) + [0.25..., 0.23..., 0.21..., 0.19..., 0.18...] .. versionadded:: 1.4.0 """ @@ -1319,6 +1323,17 @@ def trees(self): """Trees in this ensemble. Warning: These have null parent Estimators.""" return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))] + @since("2.4.0") + def evaluateEachIteration(self, dataset): + """ + Method to compute error or loss for every iteration of gradient boosting. + + :param dataset: + Test dataset to evaluate model on, where dataset is an + instance of :py:class:`pyspark.sql.DataFrame` + """ + return self._call_java("evaluateEachIteration", dataset) + @inherit_doc class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol, diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 9a66d87d7f211..dd0b62f184d26 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -1056,6 +1056,10 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, True >>> model.trees [DecisionTreeRegressionModel (uid=...) of depth..., DecisionTreeRegressionModel...] + >>> validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0))], + ... ["label", "features"]) + >>> model.evaluateEachIteration(validation, "squared") + [0.0, 0.0, 0.0, 0.0, 0.0] .. versionadded:: 1.4.0 """ @@ -1156,6 +1160,20 @@ def trees(self): """Trees in this ensemble. Warning: These have null parent Estimators.""" return [DecisionTreeRegressionModel(m) for m in list(self._call_java("trees"))] + @since("2.4.0") + def evaluateEachIteration(self, dataset, loss): + """ + Method to compute error or loss for every iteration of gradient boosting. + + :param dataset: + Test dataset to evaluate model on, where dataset is an + instance of :py:class:`pyspark.sql.DataFrame` + :param loss: + The loss function used to compute error. + Supported options: squared, absolute + """ + return self._call_java("evaluateEachIteration", dataset, loss) + @inherit_doc class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 093593132e56d..0dde0db9e3339 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1595,6 +1595,44 @@ def test_default_read_write(self): self.assertEqual(lr.uid, lr3.uid) self.assertEqual(lr.extractParamMap(), lr3.extractParamMap()) + def test_default_read_write_default_params(self): + lr = LogisticRegression() + self.assertFalse(lr.isSet(lr.getParam("threshold"))) + + lr.setMaxIter(50) + lr.setThreshold(.75) + + # `threshold` is set by user, default param `predictionCol` is not set by user. + self.assertTrue(lr.isSet(lr.getParam("threshold"))) + self.assertFalse(lr.isSet(lr.getParam("predictionCol"))) + self.assertTrue(lr.hasDefault(lr.getParam("predictionCol"))) + + writer = DefaultParamsWriter(lr) + metadata = json.loads(writer._get_metadata_to_save(lr, self.sc)) + self.assertTrue("defaultParamMap" in metadata) + + reader = DefaultParamsReadable.read() + metadataStr = json.dumps(metadata, separators=[',', ':']) + loadedMetadata = reader._parseMetaData(metadataStr, ) + reader.getAndSetParams(lr, loadedMetadata) + + self.assertTrue(lr.isSet(lr.getParam("threshold"))) + self.assertFalse(lr.isSet(lr.getParam("predictionCol"))) + self.assertTrue(lr.hasDefault(lr.getParam("predictionCol"))) + + # manually create metadata without `defaultParamMap` section. + del metadata['defaultParamMap'] + metadataStr = json.dumps(metadata, separators=[',', ':']) + loadedMetadata = reader._parseMetaData(metadataStr, ) + with self.assertRaisesRegexp(AssertionError, "`defaultParamMap` section not found"): + reader.getAndSetParams(lr, loadedMetadata) + + # Prior to 2.4.0, metadata doesn't have `defaultParamMap`. + metadata['sparkVersion'] = '2.3.0' + metadataStr = json.dumps(metadata, separators=[',', ':']) + loadedMetadata = reader._parseMetaData(metadataStr, ) + reader.getAndSetParams(lr, loadedMetadata) + class LDATest(SparkSessionTestCase): diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index a486c6a3fdeb5..9fa85664939b8 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -30,6 +30,7 @@ from pyspark import SparkContext, since from pyspark.ml.common import inherit_doc from pyspark.sql import SparkSession +from pyspark.util import VersionUtils def _jvm(): @@ -396,6 +397,7 @@ def saveMetadata(instance, path, sc, extraMetadata=None, paramMap=None): - sparkVersion - uid - paramMap + - defaultParamMap (since 2.4.0) - (optionally, extra metadata) :param extraMetadata: Extra metadata to be saved at same level as uid, paramMap, etc. :param paramMap: If given, this is saved in the "paramMap" field. @@ -417,15 +419,24 @@ def _get_metadata_to_save(instance, sc, extraMetadata=None, paramMap=None): """ uid = instance.uid cls = instance.__module__ + '.' + instance.__class__.__name__ - params = instance.extractParamMap() + + # User-supplied param values + params = instance._paramMap jsonParams = {} if paramMap is not None: jsonParams = paramMap else: for p in params: jsonParams[p.name] = params[p] + + # Default param values + jsonDefaultParams = {} + for p in instance._defaultParamMap: + jsonDefaultParams[p.name] = instance._defaultParamMap[p] + basicMetadata = {"class": cls, "timestamp": long(round(time.time() * 1000)), - "sparkVersion": sc.version, "uid": uid, "paramMap": jsonParams} + "sparkVersion": sc.version, "uid": uid, "paramMap": jsonParams, + "defaultParamMap": jsonDefaultParams} if extraMetadata is not None: basicMetadata.update(extraMetadata) return json.dumps(basicMetadata, separators=[',', ':']) @@ -523,11 +534,26 @@ def getAndSetParams(instance, metadata): """ Extract Params from metadata, and set them in the instance. """ + # Set user-supplied param values for paramName in metadata['paramMap']: param = instance.getParam(paramName) paramValue = metadata['paramMap'][paramName] instance.set(param, paramValue) + # Set default param values + majorAndMinorVersions = VersionUtils.majorMinorVersion(metadata['sparkVersion']) + major = majorAndMinorVersions[0] + minor = majorAndMinorVersions[1] + + # For metadata file prior to Spark 2.4, there is no default section. + if major > 2 or (major == 2 and minor >= 4): + assert 'defaultParamMap' in metadata, "Error loading metadata: Expected " + \ + "`defaultParamMap` section not found" + + for paramName in metadata['defaultParamMap']: + paramValue = metadata['defaultParamMap'][paramName] + instance._setDefault(**{paramName: paramValue}) + @staticmethod def loadParamsInstance(path, sc): """ diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index f5a584152b4f6..fbc8a2d038f8f 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1108,8 +1108,11 @@ def add_months(start, months): @since(1.5) def months_between(date1, date2, roundOff=True): """ - Returns the number of months between date1 and date2. - Unless `roundOff` is set to `False`, the result is rounded off to 8 digits. + Returns number of months between dates date1 and date2. + If date1 is later than date2, then the result is positive. + If date1 and date2 are on the same day of month, or both are the last day of month, + returns an integer (time of day will be ignored). + The result is rounded off to 8 digits unless `roundOff` is set to `False`. >>> df = spark.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['date1', 'date2']) >>> df.select(months_between(df.date1, df.date2).alias('months')).collect() @@ -1852,6 +1855,21 @@ def array_contains(col, value): return Column(sc._jvm.functions.array_contains(_to_java_column(col), value)) +@since(2.4) +def arrays_overlap(a1, a2): + """ + Collection function: returns true if the arrays contain any common non-null element; if not, + returns null if both the arrays are non-empty and any of them contains a null element; returns + false otherwise. + + >>> df = spark.createDataFrame([(["a", "b"], ["b", "c"]), (["a"], ["b", "c"])], ['x', 'y']) + >>> df.select(arrays_overlap(df.x, df.y).alias("overlap")).collect() + [Row(overlap=True), Row(overlap=False)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.arrays_overlap(_to_java_column(a1), _to_java_column(a2))) + + @since(2.4) def slice(x, start, length): """ @@ -2092,12 +2110,13 @@ def json_tuple(col, *fields): return Column(jc) +@ignore_unicode_prefix @since(2.1) def from_json(col, schema, options={}): """ - Parses a column containing a JSON string into a :class:`StructType` or :class:`ArrayType` - of :class:`StructType`\\s with the specified schema. Returns `null`, in the case of an - unparseable string. + Parses a column containing a JSON string into a :class:`MapType` with :class:`StringType` + as keys type, :class:`StructType` or :class:`ArrayType` of :class:`StructType`\\s with + the specified schema. Returns `null`, in the case of an unparseable string. :param col: string column in json format :param schema: a StructType or ArrayType of StructType to use when parsing the json column. @@ -2114,6 +2133,9 @@ def from_json(col, schema, options={}): [Row(json=Row(a=1))] >>> df.select(from_json(df.value, "a INT").alias("json")).collect() [Row(json=Row(a=1))] + >>> schema = MapType(StringType(), IntegerType()) + >>> df.select(from_json(df.value, schema).alias("json")).collect() + [Row(json={u'a': 1})] >>> data = [(1, '''[{"a": 1}]''')] >>> schema = ArrayType(StructType([StructField("a", IntegerType())])) >>> df = spark.createDataFrame(data, ("key", "value")) @@ -2322,6 +2344,40 @@ def map_values(col): return Column(sc._jvm.functions.map_values(_to_java_column(col))) +@since(2.4) +def map_entries(col): + """ + Collection function: Returns an unordered array of all entries in the given map. + + :param col: name of column or expression + + >>> from pyspark.sql.functions import map_entries + >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as data") + >>> df.select(map_entries("data").alias("entries")).show() + +----------------+ + | entries| + +----------------+ + |[[1, a], [2, b]]| + +----------------+ + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.map_entries(_to_java_column(col))) + + +@ignore_unicode_prefix +@since(2.4) +def array_repeat(col, count): + """ + Collection function: creates an array containing a column repeated count times. + + >>> df = spark.createDataFrame([('ab',)], ['data']) + >>> df.select(array_repeat(df.data, 3).alias('r')).collect() + [Row(r=[u'ab', u'ab', u'ab'])] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_repeat(_to_java_column(col), count)) + + # ---------------------------- User Defined Function ---------------------------------- class PandasUDFType(object): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 16aa9378ad8ee..c7bd8f01b907f 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -4680,6 +4680,26 @@ def test_supported_types(self): self.assertPandasEqual(expected2, result2) self.assertPandasEqual(expected3, result3) + def test_array_type_correct(self): + from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col + + df = self.data.withColumn("arr", array(col("id"))).repartition(1, "id") + + output_schema = StructType( + [StructField('id', LongType()), + StructField('v', IntegerType()), + StructField('arr', ArrayType(LongType()))]) + + udf = pandas_udf( + lambda pdf: pdf, + output_schema, + PandasUDFType.GROUPED_MAP + ) + + result = df.groupby('id').apply(udf).sort('id').toPandas() + expected = df.toPandas().groupby('id').apply(udf.func).reset_index(drop=True) + self.assertPandasEqual(expected, result) + def test_register_grouped_map_udf(self): from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -5219,8 +5239,8 @@ def test_complex_groupby(self): expected2 = df.groupby().agg(sum(df.v)) # groupby one column and one sql expression - result3 = df.groupby(df.id, df.v % 2).agg(sum_udf(df.v)) - expected3 = df.groupby(df.id, df.v % 2).agg(sum(df.v)) + result3 = df.groupby(df.id, df.v % 2).agg(sum_udf(df.v)).orderBy(df.id, df.v % 2) + expected3 = df.groupby(df.id, df.v % 2).agg(sum(df.v)).orderBy(df.id, df.v % 2) # groupby one python UDF result4 = df.groupby(plus_one(df.id)).agg(sum_udf(df.v)) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 8bb63fcc7ff9c..5d2e58bef6466 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -82,7 +82,7 @@ def wrap_scalar_pandas_udf(f, return_type): def verify_result_length(*a): result = f(*a) if not hasattr(result, "__len__"): - raise TypeError("Return type of the user-defined functon should be " + raise TypeError("Return type of the user-defined function should be " "Pandas.Series, but is {}".format(type(result))) if len(result) != len(a[0]): raise RuntimeError("Result vector from pandas_udf was not the required length: " diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala index 022191d0070fd..91f64141e5318 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala @@ -39,7 +39,7 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver")

    Cannot find driver {driverId}

    - return UIUtils.basicSparkPage(content, s"Details for Job $driverId") + return UIUtils.basicSparkPage(request, content, s"Details for Job $driverId") } val driverState = state.get val driverHeaders = Seq("Driver property", "Value") @@ -68,7 +68,7 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver") retryHeaders, retryRow, Iterable.apply(driverState.description.retryState)) val content =

    Driver state information for driver id {driverId}

    - Back to Drivers + Back to Drivers

    Driver state: {driverState.state}

    @@ -87,7 +87,7 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver")
    ; - UIUtils.basicSparkPage(content, s"Details for Job $driverId") + UIUtils.basicSparkPage(request, content, s"Details for Job $driverId") } private def launchedRow(submissionState: Option[MesosClusterSubmissionState]): Seq[Node] = { diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala index 88a6614d51384..c53285331ea68 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala @@ -62,7 +62,7 @@ private[mesos] class MesosClusterPage(parent: MesosClusterUI) extends WebUIPage( {retryTable} ; - UIUtils.basicSparkPage(content, "Spark Drivers for Mesos cluster") + UIUtils.basicSparkPage(request, content, "Spark Drivers for Mesos cluster") } private def queuedRow(submission: MesosDriverDescription): Seq[Node] = { diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 8eda6cb1277c5..7250e58b6c49a 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -200,7 +200,29 @@ object YarnSparkHadoopUtil { .map(new Path(_).getFileSystem(hadoopConf)) .getOrElse(FileSystem.get(hadoopConf)) - filesystemsToAccess + stagingFS + // Add the list of available namenodes for all namespaces in HDFS federation. + // If ViewFS is enabled, this is skipped as ViewFS already handles delegation tokens for its + // namespaces. + val hadoopFilesystems = if (stagingFS.getScheme == "viewfs") { + Set.empty + } else { + val nameservices = hadoopConf.getTrimmedStrings("dfs.nameservices") + // Retrieving the filesystem for the nameservices where HA is not enabled + val filesystemsWithoutHA = nameservices.flatMap { ns => + Option(hadoopConf.get(s"dfs.namenode.rpc-address.$ns")).map { nameNode => + new Path(s"hdfs://$nameNode").getFileSystem(hadoopConf) + } + } + // Retrieving the filesystem for the nameservices where HA is enabled + val filesystemsWithHA = nameservices.flatMap { ns => + Option(hadoopConf.get(s"dfs.ha.namenodes.$ns")).map { _ => + new Path(s"hdfs://$ns").getFileSystem(hadoopConf) + } + } + (filesystemsWithoutHA ++ filesystemsWithHA).toSet + } + + filesystemsToAccess ++ hadoopFilesystems + stagingFS } } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala index f21353aa007c8..61c0c43f7c04f 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala @@ -21,7 +21,8 @@ import java.io.{File, IOException} import java.nio.charset.StandardCharsets import com.google.common.io.{ByteStreams, Files} -import org.apache.hadoop.io.Text +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path import org.apache.hadoop.yarn.api.records.ApplicationAccessType import org.apache.hadoop.yarn.conf.YarnConfiguration import org.scalatest.Matchers @@ -141,4 +142,66 @@ class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging } + test("SPARK-24149: retrieve all namenodes from HDFS") { + val sparkConf = new SparkConf() + val basicFederationConf = new Configuration() + basicFederationConf.set("fs.defaultFS", "hdfs://localhost:8020") + basicFederationConf.set("dfs.nameservices", "ns1,ns2") + basicFederationConf.set("dfs.namenode.rpc-address.ns1", "localhost:8020") + basicFederationConf.set("dfs.namenode.rpc-address.ns2", "localhost:8021") + val basicFederationExpected = Set( + new Path("hdfs://localhost:8020").getFileSystem(basicFederationConf), + new Path("hdfs://localhost:8021").getFileSystem(basicFederationConf)) + val basicFederationResult = YarnSparkHadoopUtil.hadoopFSsToAccess( + sparkConf, basicFederationConf) + basicFederationResult should be (basicFederationExpected) + + // when viewfs is enabled, namespaces are handled by it, so we don't need to take care of them + val viewFsConf = new Configuration() + viewFsConf.addResource(basicFederationConf) + viewFsConf.set("fs.defaultFS", "viewfs://clusterX/") + viewFsConf.set("fs.viewfs.mounttable.clusterX.link./home", "hdfs://localhost:8020/") + val viewFsExpected = Set(new Path("viewfs://clusterX/").getFileSystem(viewFsConf)) + YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, viewFsConf) should be (viewFsExpected) + + // invalid config should not throw NullPointerException + val invalidFederationConf = new Configuration() + invalidFederationConf.addResource(basicFederationConf) + invalidFederationConf.unset("dfs.namenode.rpc-address.ns2") + val invalidFederationExpected = Set( + new Path("hdfs://localhost:8020").getFileSystem(invalidFederationConf)) + val invalidFederationResult = YarnSparkHadoopUtil.hadoopFSsToAccess( + sparkConf, invalidFederationConf) + invalidFederationResult should be (invalidFederationExpected) + + // no namespaces defined, ie. old case + val noFederationConf = new Configuration() + noFederationConf.set("fs.defaultFS", "hdfs://localhost:8020") + val noFederationExpected = Set( + new Path("hdfs://localhost:8020").getFileSystem(noFederationConf)) + val noFederationResult = YarnSparkHadoopUtil.hadoopFSsToAccess(sparkConf, noFederationConf) + noFederationResult should be (noFederationExpected) + + // federation and HA enabled + val federationAndHAConf = new Configuration() + federationAndHAConf.set("fs.defaultFS", "hdfs://clusterXHA") + federationAndHAConf.set("dfs.nameservices", "clusterXHA,clusterYHA") + federationAndHAConf.set("dfs.ha.namenodes.clusterXHA", "x-nn1,x-nn2") + federationAndHAConf.set("dfs.ha.namenodes.clusterYHA", "y-nn1,y-nn2") + federationAndHAConf.set("dfs.namenode.rpc-address.clusterXHA.x-nn1", "localhost:8020") + federationAndHAConf.set("dfs.namenode.rpc-address.clusterXHA.x-nn2", "localhost:8021") + federationAndHAConf.set("dfs.namenode.rpc-address.clusterYHA.y-nn1", "localhost:8022") + federationAndHAConf.set("dfs.namenode.rpc-address.clusterYHA.y-nn2", "localhost:8023") + federationAndHAConf.set("dfs.client.failover.proxy.provider.clusterXHA", + "org.apache.hadoop.hdfs.server.namenode.ha.ConfiguredFailoverProxyProvider") + federationAndHAConf.set("dfs.client.failover.proxy.provider.clusterYHA", + "org.apache.hadoop.hdfs.server.namenode.ha.ConfiguredFailoverProxyProvider") + + val federationAndHAExpected = Set( + new Path("hdfs://clusterXHA").getFileSystem(federationAndHAConf), + new Path("hdfs://clusterYHA").getFileSystem(federationAndHAConf)) + val federationAndHAResult = YarnSparkHadoopUtil.hadoopFSsToAccess( + sparkConf, federationAndHAConf) + federationAndHAResult should be (federationAndHAExpected) + } } diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index f7f921ec22c35..7c54851097af3 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -398,7 +398,7 @@ hintStatement ; fromClause - : FROM relation (',' relation)* (pivotClause | lateralView*)? + : FROM relation (',' relation)* lateralView* pivotClause? ; aggregation diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 29a1411241cf6..469b0e60cc9a2 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -62,6 +62,8 @@ */ public final class UnsafeRow extends InternalRow implements Externalizable, KryoSerializable { + public static final int WORD_SIZE = 8; + ////////////////////////////////////////////////////////////////////////////// // Static methods ////////////////////////////////////////////////////////////////////////////// diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index ccdb6bc5d4b7c..7b02317b8538f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -68,10 +68,10 @@ import org.apache.spark.sql.types._ */ @Experimental @InterfaceStability.Evolving -@implicitNotFound("Unable to find encoder for type stored in a Dataset. Primitive types " + - "(Int, String, etc) and Product types (case classes) are supported by importing " + - "spark.implicits._ Support for serializing other types will be added in future " + - "releases.") +@implicitNotFound("Unable to find encoder for type ${T}. An implicit Encoder[${T}] is needed to " + + "store ${T} instances in a Dataset. Primitive types (Int, String, etc) and Product types (case " + + "classes) are supported by importing spark.implicits._ Support for serializing other types " + + "will be added in future releases.") trait Encoder[T] extends Serializable { /** Returns the schema of encoding this type of object as a Row. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index dfdcdbc1eb2c7..3eaa9ecf5d075 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -676,13 +676,13 @@ class Analyzer( try { catalog.lookupRelation(tableIdentWithDb) } catch { - case _: NoSuchTableException => - u.failAnalysis(s"Table or view not found: ${tableIdentWithDb.unquotedString}") + case e: NoSuchTableException => + u.failAnalysis(s"Table or view not found: ${tableIdentWithDb.unquotedString}", e) // If the database is defined and that database is not found, throw an AnalysisException. // Note that if the database is not defined, it is possible we are looking up a temp view. case e: NoSuchDatabaseException => u.failAnalysis(s"Table or view not found: ${tableIdentWithDb.unquotedString}, the " + - s"database ${e.db} doesn't exist.") + s"database ${e.db} doesn't exist.", e) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 94b0561529e71..90bda2a72ad82 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** @@ -261,9 +260,7 @@ trait CheckAnalysis extends PredicateHelper { // Check if the data types match. dataTypes(child).zip(ref).zipWithIndex.foreach { case ((dt1, dt2), ci) => // SPARK-18058: we shall not care about the nullability of columns - val widerType = TypeCoercion.findWiderTypeForTwo( - dt1.asNullable, dt2.asNullable, SQLConf.get.caseSensitiveAnalysis) - if (widerType.isEmpty) { + if (TypeCoercion.findWiderTypeForTwo(dt1.asNullable, dt2.asNullable).isEmpty) { failAnalysis( s""" |${operator.nodeName} can only be performed on tables with the compatible diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index f9fde53be22ae..23a4a440fac23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -410,6 +410,7 @@ object FunctionRegistry { // collection functions expression[CreateArray]("array"), expression[ArrayContains]("array_contains"), + expression[ArraysOverlap]("arrays_overlap"), expression[ArrayJoin]("array_join"), expression[ArrayPosition]("array_position"), expression[ArraySort]("array_sort"), @@ -418,6 +419,7 @@ object FunctionRegistry { expression[ElementAt]("element_at"), expression[MapKeys]("map_keys"), expression[MapValues]("map_values"), + expression[MapEntries]("map_entries"), expression[Size]("size"), expression[Slice]("slice"), expression[Size]("cardinality"), @@ -427,6 +429,7 @@ object FunctionRegistry { expression[Reverse]("reverse"), expression[Concat]("concat"), expression[Flatten]("flatten"), + expression[ArrayRepeat]("array_repeat"), CreateStruct.registryEntry, // mask functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala index 4eb6e642b1c37..71ed75454cd4d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala @@ -83,9 +83,7 @@ case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] with Cas // For each column, traverse all the values and find a common data type and nullability. val fields = table.rows.transpose.zip(table.names).map { case (column, name) => val inputTypes = column.map(_.dataType) - val wideType = TypeCoercion.findWiderTypeWithoutStringPromotion( - inputTypes, conf.caseSensitiveAnalysis) - val tpe = wideType.getOrElse { + val tpe = TypeCoercion.findWiderTypeWithoutStringPromotion(inputTypes).getOrElse { table.failAnalysis(s"incompatible types found in column $name for inline table") } StructField(name, tpe, nullable = column.exists(_.nullable)) @@ -105,7 +103,7 @@ case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] with Cas castedExpr.eval() } catch { case NonFatal(ex) => - table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}") + table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}", ex) } }) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index a7ba201509b78..b2817b0538a7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -48,18 +48,18 @@ object TypeCoercion { def typeCoercionRules(conf: SQLConf): List[Rule[LogicalPlan]] = InConversion(conf) :: - WidenSetOperationTypes(conf) :: + WidenSetOperationTypes :: PromoteStrings(conf) :: DecimalPrecision :: BooleanEquality :: - FunctionArgumentConversion(conf) :: + FunctionArgumentConversion :: ConcatCoercion(conf) :: EltCoercion(conf) :: - CaseWhenCoercion(conf) :: - IfCoercion(conf) :: + CaseWhenCoercion :: + IfCoercion :: StackCoercion :: Division :: - ImplicitTypeCasts(conf) :: + new ImplicitTypeCasts(conf) :: DateTimeOperations :: WindowFrameCoercion :: Nil @@ -83,10 +83,7 @@ object TypeCoercion { * with primitive types, because in that case the precision and scale of the result depends on * the operation. Those rules are implemented in [[DecimalPrecision]]. */ - def findTightestCommonType( - left: DataType, - right: DataType, - caseSensitive: Boolean): Option[DataType] = (left, right) match { + val findTightestCommonType: (DataType, DataType) => Option[DataType] = { case (t1, t2) if t1 == t2 => Some(t1) case (NullType, t1) => Some(t1) case (t1, NullType) => Some(t1) @@ -105,32 +102,22 @@ object TypeCoercion { case (_: TimestampType, _: DateType) | (_: DateType, _: TimestampType) => Some(TimestampType) - case (t1 @ StructType(fields1), t2 @ StructType(fields2)) => - val isSameType = if (caseSensitive) { - DataType.equalsIgnoreNullability(t1, t2) - } else { - DataType.equalsIgnoreCaseAndNullability(t1, t2) - } - - if (isSameType) { - Some(StructType(fields1.zip(fields2).map { case (f1, f2) => - // Since t1 is same type of t2, two StructTypes have the same DataType - // except `name` (in case of `spark.sql.caseSensitive=false`) and `nullable`. - // - Different names: use f1.name - // - Different nullabilities: `nullable` is true iff one of them is nullable. - val dataType = findTightestCommonType(f1.dataType, f2.dataType, caseSensitive).get - StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable) - })) - } else { - None - } + case (t1 @ StructType(fields1), t2 @ StructType(fields2)) if t1.sameType(t2) => + Some(StructType(fields1.zip(fields2).map { case (f1, f2) => + // Since `t1.sameType(t2)` is true, two StructTypes have the same DataType + // except `name` (in case of `spark.sql.caseSensitive=false`) and `nullable`. + // - Different names: use f1.name + // - Different nullabilities: `nullable` is true iff one of them is nullable. + val dataType = findTightestCommonType(f1.dataType, f2.dataType).get + StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable) + })) case (a1 @ ArrayType(et1, hasNull1), a2 @ ArrayType(et2, hasNull2)) if a1.sameType(a2) => - findTightestCommonType(et1, et2, caseSensitive).map(ArrayType(_, hasNull1 || hasNull2)) + findTightestCommonType(et1, et2).map(ArrayType(_, hasNull1 || hasNull2)) case (m1 @ MapType(kt1, vt1, hasNull1), m2 @ MapType(kt2, vt2, hasNull2)) if m1.sameType(m2) => - val keyType = findTightestCommonType(kt1, kt2, caseSensitive) - val valueType = findTightestCommonType(vt1, vt2, caseSensitive) + val keyType = findTightestCommonType(kt1, kt2) + val valueType = findTightestCommonType(vt1, vt2) Some(MapType(keyType.get, valueType.get, hasNull1 || hasNull2)) case _ => None @@ -185,14 +172,13 @@ object TypeCoercion { * i.e. the main difference with [[findTightestCommonType]] is that here we allow some * loss of precision when widening decimal and double, and promotion to string. */ - def findWiderTypeForTwo(t1: DataType, t2: DataType, caseSensitive: Boolean): Option[DataType] = { - findTightestCommonType(t1, t2, caseSensitive) + def findWiderTypeForTwo(t1: DataType, t2: DataType): Option[DataType] = { + findTightestCommonType(t1, t2) .orElse(findWiderTypeForDecimal(t1, t2)) .orElse(stringPromotion(t1, t2)) .orElse((t1, t2) match { case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => - findWiderTypeForTwo(et1, et2, caseSensitive) - .map(ArrayType(_, containsNull1 || containsNull2)) + findWiderTypeForTwo(et1, et2).map(ArrayType(_, containsNull1 || containsNull2)) case _ => None }) } @@ -207,8 +193,7 @@ object TypeCoercion { case _ => false } - private def findWiderCommonType( - types: Seq[DataType], caseSensitive: Boolean): Option[DataType] = { + private def findWiderCommonType(types: Seq[DataType]): Option[DataType] = { // findWiderTypeForTwo doesn't satisfy the associative law, i.e. (a op b) op c may not equal // to a op (b op c). This is only a problem for StringType or nested StringType in ArrayType. // Excluding these types, findWiderTypeForTwo satisfies the associative law. For instance, @@ -216,7 +201,7 @@ object TypeCoercion { val (stringTypes, nonStringTypes) = types.partition(hasStringType(_)) (stringTypes.distinct ++ nonStringTypes).foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { - case Some(d) => findWiderTypeForTwo(d, c, caseSensitive) + case Some(d) => findWiderTypeForTwo(d, c) case _ => None }) } @@ -228,22 +213,20 @@ object TypeCoercion { */ private[analysis] def findWiderTypeWithoutStringPromotionForTwo( t1: DataType, - t2: DataType, - caseSensitive: Boolean): Option[DataType] = { - findTightestCommonType(t1, t2, caseSensitive) + t2: DataType): Option[DataType] = { + findTightestCommonType(t1, t2) .orElse(findWiderTypeForDecimal(t1, t2)) .orElse((t1, t2) match { case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => - findWiderTypeWithoutStringPromotionForTwo(et1, et2, caseSensitive) + findWiderTypeWithoutStringPromotionForTwo(et1, et2) .map(ArrayType(_, containsNull1 || containsNull2)) case _ => None }) } - def findWiderTypeWithoutStringPromotion( - types: Seq[DataType], caseSensitive: Boolean): Option[DataType] = { + def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = { types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { - case Some(d) => findWiderTypeWithoutStringPromotionForTwo(d, c, caseSensitive) + case Some(d) => findWiderTypeWithoutStringPromotionForTwo(d, c) case None => None }) } @@ -296,32 +279,29 @@ object TypeCoercion { * * This rule is only applied to Union/Except/Intersect */ - case class WidenSetOperationTypes(conf: SQLConf) extends Rule[LogicalPlan] { + object WidenSetOperationTypes extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case s @ SetOperation(left, right) if s.childrenResolved && left.output.length == right.output.length && !s.resolved => - val newChildren: Seq[LogicalPlan] = - buildNewChildrenWithWiderTypes(left :: right :: Nil, conf.caseSensitiveAnalysis) + val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil) assert(newChildren.length == 2) s.makeCopy(Array(newChildren.head, newChildren.last)) case s: Union if s.childrenResolved && s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved => - val newChildren: Seq[LogicalPlan] = - buildNewChildrenWithWiderTypes(s.children, conf.caseSensitiveAnalysis) + val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(s.children) s.makeCopy(Array(newChildren)) } /** Build new children with the widest types for each attribute among all the children */ - private def buildNewChildrenWithWiderTypes( - children: Seq[LogicalPlan], caseSensitive: Boolean): Seq[LogicalPlan] = { + private def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan]): Seq[LogicalPlan] = { require(children.forall(_.output.length == children.head.output.length)) // Get a sequence of data types, each of which is the widest type of this specific attribute // in all the children val targetTypes: Seq[DataType] = - getWidestTypes(children, attrIndex = 0, mutable.Queue[DataType](), caseSensitive) + getWidestTypes(children, attrIndex = 0, mutable.Queue[DataType]()) if (targetTypes.nonEmpty) { // Add an extra Project if the targetTypes are different from the original types. @@ -336,19 +316,18 @@ object TypeCoercion { @tailrec private def getWidestTypes( children: Seq[LogicalPlan], attrIndex: Int, - castedTypes: mutable.Queue[DataType], - caseSensitive: Boolean): Seq[DataType] = { + castedTypes: mutable.Queue[DataType]): Seq[DataType] = { // Return the result after the widen data types have been found for all the children if (attrIndex >= children.head.output.length) return castedTypes.toSeq // For the attrIndex-th attribute, find the widest type - findWiderCommonType(children.map(_.output(attrIndex).dataType), caseSensitive) match { + findWiderCommonType(children.map(_.output(attrIndex).dataType)) match { // If unable to find an appropriate widen type for this column, return an empty Seq case None => Seq.empty[DataType] // Otherwise, record the result in the queue and find the type for the next column case Some(widenType) => castedTypes.enqueue(widenType) - getWidestTypes(children, attrIndex + 1, castedTypes, caseSensitive) + getWidestTypes(children, attrIndex + 1, castedTypes) } } @@ -453,7 +432,7 @@ object TypeCoercion { val commonTypes = lhs.zip(rhs).flatMap { case (l, r) => findCommonTypeForBinaryComparison(l.dataType, r.dataType, conf) - .orElse(findTightestCommonType(l.dataType, r.dataType, conf.caseSensitiveAnalysis)) + .orElse(findTightestCommonType(l.dataType, r.dataType)) } // The number of columns/expressions must match between LHS and RHS of an @@ -482,7 +461,7 @@ object TypeCoercion { } case i @ In(a, b) if b.exists(_.dataType != a.dataType) => - findWiderCommonType(i.children.map(_.dataType), conf.caseSensitiveAnalysis) match { + findWiderCommonType(i.children.map(_.dataType)) match { case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType))) case None => i } @@ -536,7 +515,7 @@ object TypeCoercion { /** * This ensure that the types for various functions are as expected. */ - case class FunctionArgumentConversion(conf: SQLConf) extends TypeCoercionRule { + object FunctionArgumentConversion extends TypeCoercionRule { override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. @@ -544,7 +523,7 @@ object TypeCoercion { case a @ CreateArray(children) if !haveSameType(children) => val types = children.map(_.dataType) - findWiderCommonType(types, conf.caseSensitiveAnalysis) match { + findWiderCommonType(types) match { case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType))) case None => a } @@ -552,7 +531,7 @@ object TypeCoercion { case c @ Concat(children) if children.forall(c => ArrayType.acceptsType(c.dataType)) && !haveSameType(children) => val types = children.map(_.dataType) - findWiderCommonType(types, conf.caseSensitiveAnalysis) match { + findWiderCommonType(types) match { case Some(finalDataType) => Concat(children.map(Cast(_, finalDataType))) case None => c } @@ -563,7 +542,7 @@ object TypeCoercion { m.keys } else { val types = m.keys.map(_.dataType) - findWiderCommonType(types, conf.caseSensitiveAnalysis) match { + findWiderCommonType(types) match { case Some(finalDataType) => m.keys.map(Cast(_, finalDataType)) case None => m.keys } @@ -573,7 +552,7 @@ object TypeCoercion { m.values } else { val types = m.values.map(_.dataType) - findWiderCommonType(types, conf.caseSensitiveAnalysis) match { + findWiderCommonType(types) match { case Some(finalDataType) => m.values.map(Cast(_, finalDataType)) case None => m.values } @@ -601,7 +580,7 @@ object TypeCoercion { // compatible with every child column. case c @ Coalesce(es) if !haveSameType(es) => val types = es.map(_.dataType) - findWiderCommonType(types, conf.caseSensitiveAnalysis) match { + findWiderCommonType(types) match { case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType))) case None => c } @@ -611,14 +590,14 @@ object TypeCoercion { // string.g case g @ Greatest(children) if !haveSameType(children) => val types = children.map(_.dataType) - findWiderTypeWithoutStringPromotion(types, conf.caseSensitiveAnalysis) match { + findWiderTypeWithoutStringPromotion(types) match { case Some(finalDataType) => Greatest(children.map(Cast(_, finalDataType))) case None => g } case l @ Least(children) if !haveSameType(children) => val types = children.map(_.dataType) - findWiderTypeWithoutStringPromotion(types, conf.caseSensitiveAnalysis) match { + findWiderTypeWithoutStringPromotion(types) match { case Some(finalDataType) => Least(children.map(Cast(_, finalDataType))) case None => l } @@ -658,11 +637,11 @@ object TypeCoercion { /** * Coerces the type of different branches of a CASE WHEN statement to a common type. */ - case class CaseWhenCoercion(conf: SQLConf) extends TypeCoercionRule { + object CaseWhenCoercion extends TypeCoercionRule { override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case c: CaseWhen if c.childrenResolved && !c.valueTypesEqual => - val maybeCommonType = findWiderCommonType(c.valueTypes, conf.caseSensitiveAnalysis) + val maybeCommonType = findWiderCommonType(c.valueTypes) maybeCommonType.map { commonType => var changed = false val newBranches = c.branches.map { case (condition, value) => @@ -689,17 +668,16 @@ object TypeCoercion { /** * Coerces the type of different branches of If statement to a common type. */ - case class IfCoercion(conf: SQLConf) extends TypeCoercionRule { + object IfCoercion extends TypeCoercionRule { override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case e if !e.childrenResolved => e // Find tightest common type for If, if the true value and false value have different types. case i @ If(pred, left, right) if left.dataType != right.dataType => - findWiderTypeForTwo(left.dataType, right.dataType, conf.caseSensitiveAnalysis).map { - widestType => - val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) - val newRight = if (right.dataType == widestType) right else Cast(right, widestType) - If(pred, newLeft, newRight) + findWiderTypeForTwo(left.dataType, right.dataType).map { widestType => + val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) + val newRight = if (right.dataType == widestType) right else Cast(right, widestType) + If(pred, newLeft, newRight) }.getOrElse(i) // If there is no applicable conversion, leave expression unchanged. case If(Literal(null, NullType), left, right) => If(Literal.create(null, BooleanType), left, right) @@ -798,11 +776,12 @@ object TypeCoercion { /** * Casts types according to the expected input types for [[Expression]]s. */ - case class ImplicitTypeCasts(conf: SQLConf) extends TypeCoercionRule { + class ImplicitTypeCasts(conf: SQLConf) extends TypeCoercionRule { private def rejectTzInString = conf.getConf(SQLConf.REJECT_TIMEZONE_IN_STRING) - override def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + override protected def coerceTypes( + plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -825,18 +804,17 @@ object TypeCoercion { } case b @ BinaryOperator(left, right) if left.dataType != right.dataType => - findTightestCommonType(left.dataType, right.dataType, conf.caseSensitiveAnalysis).map { - commonType => - if (b.inputType.acceptsType(commonType)) { - // If the expression accepts the tightest common type, cast to that. - val newLeft = if (left.dataType == commonType) left else Cast(left, commonType) - val newRight = if (right.dataType == commonType) right else Cast(right, commonType) - b.withNewChildren(Seq(newLeft, newRight)) - } else { - // Otherwise, don't do anything with the expression. - b - } - }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged. + findTightestCommonType(left.dataType, right.dataType).map { commonType => + if (b.inputType.acceptsType(commonType)) { + // If the expression accepts the tightest common type, cast to that. + val newLeft = if (left.dataType == commonType) left else Cast(left, commonType) + val newRight = if (right.dataType == commonType) right else Cast(right, commonType) + b.withNewChildren(Seq(newLeft, newRight)) + } else { + // Otherwise, don't do anything with the expression. + b + } + }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged. case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty => val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index d3d6c636c4ba8..2bed41672fe33 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala index 7731336d247db..354a3fa0602a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala @@ -41,6 +41,11 @@ package object analysis { def failAnalysis(msg: String): Nothing = { throw new AnalysisException(msg, t.origin.line, t.origin.startPosition) } + + /** Fails the analysis at the point where a specific tree node was parsed. */ + def failAnalysis(msg: String, cause: Throwable): Nothing = { + throw new AnalysisException(msg, t.origin.line, t.origin.startPosition, cause = Some(cause)) + } } /** Catches any AnalysisExceptions thrown by `f` and attaches `t`'s position if any. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 4cc84b27d9eb0..df3ab05e02c76 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -21,6 +21,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ /** @@ -56,13 +57,13 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) if (nullable) { ev.copy(code = - s""" + code""" |boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal); |$javaType ${ev.value} = ${ev.isNull} ? | ${CodeGenerator.defaultValue(dataType)} : ($value); """.stripMargin) } else { - ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = FalseLiteral) + ev.copy(code = code"$javaType ${ev.value} = $value;", isNull = FalseLiteral) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 12330bfa55ab9..699ea53b5df0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -23,6 +23,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -623,8 +624,13 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx) - ev.copy(code = eval.code + - castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast)) + + ev.copy(code = + code""" + ${eval.code} + // This comment is added for manually tracking reference of ${eval.value}, ${eval.isNull} + ${castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast)} + """) } // The function arguments are: `input`, `result` and `resultIsNull`. We don't need `inputIsNull` diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 97dff6ae88299..9b9fa41a47d0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -22,6 +22,7 @@ import java.util.Locale import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -108,9 +109,9 @@ abstract class Expression extends TreeNode[Expression] { JavaCode.isNullVariable(isNull), JavaCode.variable(value, dataType))) reduceCodeSize(ctx, eval) - if (eval.code.nonEmpty) { + if (eval.code.toString.nonEmpty) { // Add `this` in the comment. - eval.copy(code = s"${ctx.registerComment(this.toString)}\n" + eval.code.trim) + eval.copy(code = ctx.registerComment(this.toString) + eval.code) } else { eval } @@ -119,7 +120,7 @@ abstract class Expression extends TreeNode[Expression] { private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = { // TODO: support whole stage codegen too - if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) { + if (eval.code.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) { val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) { val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull") val localIsNull = eval.isNull @@ -136,14 +137,14 @@ abstract class Expression extends TreeNode[Expression] { val funcFullName = ctx.addNewFunction(funcName, s""" |private $javaType $funcName(InternalRow ${ctx.INPUT_ROW}) { - | ${eval.code.trim} + | ${eval.code} | $setIsNull | return ${eval.value}; |} """.stripMargin) eval.value = JavaCode.variable(newValue, dataType) - eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});" + eval.code = code"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});" } } @@ -437,15 +438,14 @@ abstract class UnaryExpression extends Expression { if (nullable) { val nullSafeEval = ctx.nullSafeExec(child.nullable, childGen.isNull)(resultCode) - ev.copy(code = s""" + ev.copy(code = code""" ${childGen.code} boolean ${ev.isNull} = ${childGen.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval """) } else { - ev.copy(code = s""" - boolean ${ev.isNull} = false; + ev.copy(code = code""" ${childGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $resultCode""", isNull = FalseLiteral) @@ -537,14 +537,13 @@ abstract class BinaryExpression extends Expression { } } - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval """) } else { - ev.copy(code = s""" - boolean ${ev.isNull} = false; + ev.copy(code = code""" ${leftGen.code} ${rightGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -681,13 +680,12 @@ abstract class TernaryExpression extends Expression { } } - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval""") } else { - ev.copy(code = s""" - boolean ${ev.isNull} = false; + ev.copy(code = code""" ${leftGen.code} ${midGen.code} ${rightGen.code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index 9f0779642271d..f1da592a76845 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.{DataType, LongType} /** @@ -72,7 +73,7 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Stateful { ctx.addPartitionInitializationStatement(s"$countTerm = 0L;") ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;") - ev.copy(code = s""" + ev.copy(code = code""" final ${CodeGenerator.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm; $countTerm++;""", isNull = FalseLiteral) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index e869258469a97..3e7ca88249737 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.DataType /** @@ -1030,7 +1031,7 @@ case class ScalaUDF( """.stripMargin ev.copy(code = - s""" + code""" |$evalCode |${initArgs.mkString("\n")} |$callFunc diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index ff7c98f714905..2ce9d072c71c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.unsafe.sort.PrefixComparators._ @@ -181,7 +182,7 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression { } ev.copy(code = childCode.code + - s""" + code""" |long ${ev.value} = 0L; |boolean ${ev.isNull} = ${childCode.isNull}; |if (!${childCode.isNull}) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index 787bcaf5e81de..9856b37e53fbc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.{DataType, IntegerType} /** @@ -46,7 +47,7 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic { val idTerm = "partitionId" ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_INT, idTerm) ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;") - ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;", + ev.copy(code = code"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;", isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala index 6c4a3601c1730..84e38a8b2711e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -164,7 +165,7 @@ case class PreciseTimestampConversion( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) ev.copy(code = eval.code + - s"""boolean ${ev.isNull} = ${eval.isNull}; + code"""boolean ${ev.isNull} = ${eval.isNull}; |${CodeGenerator.javaType(dataType)} ${ev.value} = ${eval.value}; """.stripMargin) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index d4e322d23b95b..fe91e520169b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -220,30 +221,12 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2) } -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "expr1 _FUNC_ expr2 - Returns `expr1`/`expr2`. It always performs floating point division.", - examples = """ - Examples: - > SELECT 3 _FUNC_ 2; - 1.5 - > SELECT 2L _FUNC_ 2L; - 1.0 - """) -// scalastyle:on line.size.limit -case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { - - override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType) +// Common base trait for Divide and Remainder, since these two classes are almost identical +trait DivModLike extends BinaryArithmetic { - override def symbol: String = "/" - override def decimalMethod: String = "$div" override def nullable: Boolean = true - private lazy val div: (Any, Any) => Any = dataType match { - case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div - } - - override def eval(input: InternalRow): Any = { + final override def eval(input: InternalRow): Any = { val input2 = right.eval(input) if (input2 == null || input2 == 0) { null @@ -252,13 +235,15 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic if (input1 == null) { null } else { - div(input1, input2) + evalOperation(input1, input2) } } } + def evalOperation(left: Any, right: Any): Any + /** - * Special case handling due to division by 0 => null. + * Special case handling due to division/remainder by 0 => null. */ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval1 = left.genCode(ctx) @@ -269,13 +254,13 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic s"${eval2.value} == 0" } val javaType = CodeGenerator.javaType(dataType) - val divide = if (dataType.isInstanceOf[DecimalType]) { + val operation = if (dataType.isInstanceOf[DecimalType]) { s"${eval1.value}.$decimalMethod(${eval2.value})" } else { s"($javaType)(${eval1.value} $symbol ${eval2.value})" } if (!left.nullable && !right.nullable) { - ev.copy(code = s""" + ev.copy(code = code""" ${eval2.code} boolean ${ev.isNull} = false; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -283,10 +268,10 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic ${ev.isNull} = true; } else { ${eval1.code} - ${ev.value} = $divide; + ${ev.value} = $operation; }""") } else { - ev.copy(code = s""" + ev.copy(code = code""" ${eval2.code} boolean ${ev.isNull} = false; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -297,13 +282,38 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic if (${eval1.isNull}) { ${ev.isNull} = true; } else { - ${ev.value} = $divide; + ${ev.value} = $operation; } }""") } } } +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "expr1 _FUNC_ expr2 - Returns `expr1`/`expr2`. It always performs floating point division.", + examples = """ + Examples: + > SELECT 3 _FUNC_ 2; + 1.5 + > SELECT 2L _FUNC_ 2L; + 1.0 + """) +// scalastyle:on line.size.limit +case class Divide(left: Expression, right: Expression) extends DivModLike { + + override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType) + + override def symbol: String = "/" + override def decimalMethod: String = "$div" + + private lazy val div: (Any, Any) => Any = dataType match { + case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div + } + + override def evalOperation(left: Any, right: Any): Any = div(left, right) +} + @ExpressionDescription( usage = "expr1 _FUNC_ expr2 - Returns the remainder after `expr1`/`expr2`.", examples = """ @@ -313,82 +323,30 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic > SELECT MOD(2, 1.8); 0.2 """) -case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { +case class Remainder(left: Expression, right: Expression) extends DivModLike { override def inputType: AbstractDataType = NumericType override def symbol: String = "%" override def decimalMethod: String = "remainder" - override def nullable: Boolean = true - private lazy val integral = dataType match { - case i: IntegralType => i.integral.asInstanceOf[Integral[Any]] - case i: FractionalType => i.asIntegral.asInstanceOf[Integral[Any]] + private lazy val mod: (Any, Any) => Any = dataType match { + // special cases to make float/double primitive types faster + case DoubleType => + (left, right) => left.asInstanceOf[Double] % right.asInstanceOf[Double] + case FloatType => + (left, right) => left.asInstanceOf[Float] % right.asInstanceOf[Float] + + // catch-all cases + case i: IntegralType => + val integral = i.integral.asInstanceOf[Integral[Any]] + (left, right) => integral.rem(left, right) + case i: FractionalType => // should only be DecimalType for now + val integral = i.asIntegral.asInstanceOf[Integral[Any]] + (left, right) => integral.rem(left, right) } - override def eval(input: InternalRow): Any = { - val input2 = right.eval(input) - if (input2 == null || input2 == 0) { - null - } else { - val input1 = left.eval(input) - if (input1 == null) { - null - } else { - input1 match { - case d: Double => d % input2.asInstanceOf[java.lang.Double] - case f: Float => f % input2.asInstanceOf[java.lang.Float] - case _ => integral.rem(input1, input2) - } - } - } - } - - /** - * Special case handling for x % 0 ==> null. - */ - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val eval1 = left.genCode(ctx) - val eval2 = right.genCode(ctx) - val isZero = if (dataType.isInstanceOf[DecimalType]) { - s"${eval2.value}.isZero()" - } else { - s"${eval2.value} == 0" - } - val javaType = CodeGenerator.javaType(dataType) - val remainder = if (dataType.isInstanceOf[DecimalType]) { - s"${eval1.value}.$decimalMethod(${eval2.value})" - } else { - s"($javaType)(${eval1.value} $symbol ${eval2.value})" - } - if (!left.nullable && !right.nullable) { - ev.copy(code = s""" - ${eval2.code} - boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - if ($isZero) { - ${ev.isNull} = true; - } else { - ${eval1.code} - ${ev.value} = $remainder; - }""") - } else { - ev.copy(code = s""" - ${eval2.code} - boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - if (${eval2.isNull} || $isZero) { - ${ev.isNull} = true; - } else { - ${eval1.code} - if (${eval1.isNull}) { - ${ev.isNull} = true; - } else { - ${ev.value} = $remainder; - } - }""") - } - } + override def evalOperation(left: Any, right: Any): Any = mod(left, right) } @ExpressionDescription( @@ -479,7 +437,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { } if (!left.nullable && !right.nullable) { - ev.copy(code = s""" + ev.copy(code = code""" ${eval2.code} boolean ${ev.isNull} = false; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -490,7 +448,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { $result }""") } else { - ev.copy(code = s""" + ev.copy(code = code""" ${eval2.code} boolean ${ev.isNull} = false; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -612,7 +570,7 @@ case class Least(children: Seq[Expression]) extends Expression { """.stripMargin, foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) ev.copy(code = - s""" + code""" |${ev.isNull} = true; |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |$codes @@ -687,7 +645,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { """.stripMargin, foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) ev.copy(code = - s""" + code""" |${ev.isNull} = true; |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |$codes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 4dda525294259..66315e5906253 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -38,6 +38,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.metrics.source.CodegenMetrics import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -57,19 +58,19 @@ import org.apache.spark.util.{ParentClassLoader, Utils} * @param value A term for a (possibly primitive) value of the result of the evaluation. Not * valid if `isNull` is set to `true`. */ -case class ExprCode(var code: String, var isNull: ExprValue, var value: ExprValue) +case class ExprCode(var code: Block, var isNull: ExprValue, var value: ExprValue) object ExprCode { def apply(isNull: ExprValue, value: ExprValue): ExprCode = { - ExprCode(code = "", isNull, value) + ExprCode(code = EmptyBlock, isNull, value) } def forNullValue(dataType: DataType): ExprCode = { - ExprCode(code = "", isNull = TrueLiteral, JavaCode.defaultLiteral(dataType)) + ExprCode(code = EmptyBlock, isNull = TrueLiteral, JavaCode.defaultLiteral(dataType)) } def forNonNullValue(value: ExprValue): ExprCode = { - ExprCode(code = "", isNull = FalseLiteral, value = value) + ExprCode(code = EmptyBlock, isNull = FalseLiteral, value = value) } } @@ -330,9 +331,9 @@ class CodegenContext { def addBufferedState(dataType: DataType, variableName: String, initCode: String): ExprCode = { val value = addMutableState(javaType(dataType), variableName) val code = dataType match { - case StringType => s"$value = $initCode.clone();" - case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();" - case _ => s"$value = $initCode;" + case StringType => code"$value = $initCode.clone();" + case _: StructType | _: ArrayType | _: MapType => code"$value = $initCode.copy();" + case _ => code"$value = $initCode;" } ExprCode(code, FalseLiteral, JavaCode.global(value, dataType)) } @@ -764,6 +765,40 @@ class CodegenContext { """.stripMargin } + /** + * Generates code creating a [[UnsafeArrayData]]. The generated code executes + * a provided fallback when the size of backing array would exceed the array size limit. + * @param arrayName a name of the array to create + * @param numElements a piece of code representing the number of elements the array should contain + * @param elementSize a size of an element in bytes + * @param bodyCode a function generating code that fills up the [[UnsafeArrayData]] + * and getting the backing array as a parameter + * @param fallbackCode a piece of code executed when the array size limit is exceeded + */ + def createUnsafeArrayWithFallback( + arrayName: String, + numElements: String, + elementSize: Int, + bodyCode: String => String, + fallbackCode: String): String = { + val arraySize = freshName("size") + val arrayBytes = freshName("arrayBytes") + s""" + |final long $arraySize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( + | $numElements, + | $elementSize); + |if ($arraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | $fallbackCode + |} else { + | final byte[] $arrayBytes = new byte[(int)$arraySize]; + | UnsafeArrayData $arrayName = new UnsafeArrayData(); + | Platform.putLong($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, $numElements); + | $arrayName.pointTo($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, (int)$arraySize); + | ${bodyCode(arrayBytes)} + |} + """.stripMargin + } + /** * Generates code to do null safe execution, i.e. only execute the code when the input is not * null by adding null check if necessary. @@ -1022,7 +1057,7 @@ class CodegenContext { val eval = expr.genCode(this) val state = SubExprEliminationState(eval.isNull, eval.value) e.foreach(localSubExprEliminationExprs.put(_, state)) - eval.code.trim + eval.code.toString } SubExprCodes(codes, localSubExprEliminationExprs.toMap) } @@ -1050,7 +1085,7 @@ class CodegenContext { val fn = s""" |private void $fnName(InternalRow $INPUT_ROW) { - | ${eval.code.trim} + | ${eval.code} | $isNull = ${eval.isNull}; | $value = ${eval.value}; |} @@ -1107,7 +1142,7 @@ class CodegenContext { def registerComment( text: => String, placeholderId: String = "", - force: Boolean = false): String = { + force: Boolean = false): Block = { // By default, disable comments in generated code because computing the comments themselves can // be extremely expensive in certain cases, such as deeply-nested expressions which operate over // inputs with wide schemas. For more details on the performance issues that motivated this @@ -1126,9 +1161,9 @@ class CodegenContext { s"// $text" } placeHolderToComments += (name -> comment) - s"/*$name*/" + code"/*$name*/" } else { - "" + EmptyBlock } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index a91989e129664..3f4704d287cbd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, Nondeterministic} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ /** * A trait that can be used to provide a fallback mode for expression code generation. @@ -46,7 +47,7 @@ trait CodegenFallback extends Expression { val placeHolder = ctx.registerComment(this.toString) val javaType = CodeGenerator.javaType(this.dataType) if (nullable) { - ev.copy(code = s""" + ev.copy(code = code""" $placeHolder Object $objectTerm = ((Expression) references[$idx]).eval($input); boolean ${ev.isNull} = $objectTerm == null; @@ -55,7 +56,7 @@ trait CodegenFallback extends Expression { ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm; }""") } else { - ev.copy(code = s""" + ev.copy(code = code""" $placeHolder Object $objectTerm = ((Expression) references[$idx]).eval($input); $javaType ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 01c350e9dbf69..39778661d1c48 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -22,6 +22,7 @@ import scala.annotation.tailrec import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ @@ -71,7 +72,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] arguments = Seq("InternalRow" -> tmpInput, "Object[]" -> values) ) val code = - s""" + code""" |final InternalRow $tmpInput = $input; |final Object[] $values = new Object[${schema.length}]; |$allFields @@ -97,7 +98,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] ctx, JavaCode.expression(CodeGenerator.getValue(tmpInput, elementType, index), elementType), elementType) - val code = s""" + val code = code""" final ArrayData $tmpInput = $input; final int $numElements = $tmpInput.numElements(); final Object[] $values = new Object[$numElements]; @@ -124,7 +125,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val keyConverter = createCodeForArray(ctx, s"$tmpInput.keyArray()", keyType) val valueConverter = createCodeForArray(ctx, s"$tmpInput.valueArray()", valueType) - val code = s""" + val code = code""" final MapData $tmpInput = $input; ${keyConverter.code} ${valueConverter.code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 01b4d6c4529bd..8f2a5a0dce943 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ /** @@ -286,7 +287,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx, ctx.INPUT_ROW, exprEvals, exprTypes, rowWriter, isTopLevel = true) val code = - s""" + code""" |$rowWriter.reset(); |$evalSubexpr |$writeExpressions @@ -343,7 +344,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro | } | | public UnsafeRow apply(InternalRow ${ctx.INPUT_ROW}) { - | ${eval.code.trim} + | ${eval.code} | return ${eval.value}; | } | diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index 74ff018488863..250ce48d059e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import java.lang.{Boolean => JBool} +import scala.collection.mutable.ArrayBuffer import scala.language.{existentials, implicitConversions} import org.apache.spark.sql.types.{BooleanType, DataType} @@ -114,6 +115,147 @@ object JavaCode { } } +/** + * A trait representing a block of java code. + */ +trait Block extends JavaCode { + + // The expressions to be evaluated inside this block. + def exprValues: Set[ExprValue] + + // Returns java code string for this code block. + override def toString: String = _marginChar match { + case Some(c) => code.stripMargin(c).trim + case _ => code.trim + } + + def length: Int = toString.length + + def nonEmpty: Boolean = toString.nonEmpty + + // The leading prefix that should be stripped from each line. + // By default we strip blanks or control characters followed by '|' from the line. + var _marginChar: Option[Char] = Some('|') + + def stripMargin(c: Char): this.type = { + _marginChar = Some(c) + this + } + + def stripMargin: this.type = { + _marginChar = Some('|') + this + } + + // Concatenates this block with other block. + def + (other: Block): Block +} + +object Block { + + val CODE_BLOCK_BUFFER_LENGTH: Int = 512 + + implicit def blocksToBlock(blocks: Seq[Block]): Block = Blocks(blocks) + + implicit class BlockHelper(val sc: StringContext) extends AnyVal { + def code(args: Any*): Block = { + sc.checkLengths(args) + if (sc.parts.length == 0) { + EmptyBlock + } else { + args.foreach { + case _: ExprValue => + case _: Int | _: Long | _: Float | _: Double | _: String => + case _: Block => + case other => throw new IllegalArgumentException( + s"Can not interpolate ${other.getClass.getName} into code block.") + } + + val (codeParts, blockInputs) = foldLiteralArgs(sc.parts, args) + CodeBlock(codeParts, blockInputs) + } + } + } + + // Folds eagerly the literal args into the code parts. + private def foldLiteralArgs(parts: Seq[String], args: Seq[Any]): (Seq[String], Seq[JavaCode]) = { + val codeParts = ArrayBuffer.empty[String] + val blockInputs = ArrayBuffer.empty[JavaCode] + + val strings = parts.iterator + val inputs = args.iterator + val buf = new StringBuilder(Block.CODE_BLOCK_BUFFER_LENGTH) + + buf.append(strings.next) + while (strings.hasNext) { + val input = inputs.next + input match { + case _: ExprValue | _: Block => + codeParts += buf.toString + buf.clear + blockInputs += input.asInstanceOf[JavaCode] + case _ => + buf.append(input) + } + buf.append(strings.next) + } + if (buf.nonEmpty) { + codeParts += buf.toString + } + + (codeParts.toSeq, blockInputs.toSeq) + } +} + +/** + * A block of java code. Including a sequence of code parts and some inputs to this block. + * The actual java code is generated by embedding the inputs into the code parts. + */ +case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends Block { + override lazy val exprValues: Set[ExprValue] = { + blockInputs.flatMap { + case b: Block => b.exprValues + case e: ExprValue => Set(e) + }.toSet + } + + override lazy val code: String = { + val strings = codeParts.iterator + val inputs = blockInputs.iterator + val buf = new StringBuilder(Block.CODE_BLOCK_BUFFER_LENGTH) + buf.append(StringContext.treatEscapes(strings.next)) + while (strings.hasNext) { + buf.append(inputs.next) + buf.append(StringContext.treatEscapes(strings.next)) + } + buf.toString + } + + override def + (other: Block): Block = other match { + case c: CodeBlock => Blocks(Seq(this, c)) + case b: Blocks => Blocks(Seq(this) ++ b.blocks) + case EmptyBlock => this + } +} + +case class Blocks(blocks: Seq[Block]) extends Block { + override lazy val exprValues: Set[ExprValue] = blocks.flatMap(_.exprValues).toSet + override lazy val code: String = blocks.map(_.toString).mkString("\n") + + override def + (other: Block): Block = other match { + case c: CodeBlock => Blocks(blocks :+ c) + case b: Blocks => Blocks(blocks ++ b.blocks) + case EmptyBlock => this + } +} + +object EmptyBlock extends Block with Serializable { + override val code: String = "" + override val exprValues: Set[ExprValue] = Set.empty + + override def + (other: Block): Block = other +} + /** * A typed java fragment that must be a valid java expression. */ @@ -123,10 +265,9 @@ trait ExprValue extends JavaCode { } object ExprValue { - implicit def exprValueToString(exprValue: ExprValue): String = exprValue.toString + implicit def exprValueToString(exprValue: ExprValue): String = exprValue.code } - /** * A java expression fragment. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 12b9ab2b272ab..03b3b21a16617 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -18,15 +18,53 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Comparator +import scala.collection.mutable + import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.{ByteArray, UTF8String} +/** + * Base trait for [[BinaryExpression]]s with two arrays of the same element type and implicit + * casting. + */ +trait BinaryArrayExpressionWithImplicitCast extends BinaryExpression + with ImplicitCastInputTypes { + + @transient protected lazy val elementType: DataType = + inputTypes.head.asInstanceOf[ArrayType].elementType + + override def inputTypes: Seq[AbstractDataType] = { + (left.dataType, right.dataType) match { + case (ArrayType(e1, hasNull1), ArrayType(e2, hasNull2)) => + TypeCoercion.findTightestCommonType(e1, e2) match { + case Some(dt) => Seq(ArrayType(dt, hasNull1), ArrayType(dt, hasNull2)) + case _ => Seq.empty + } + case _ => Seq.empty + } + } + + override def checkInputDataTypes(): TypeCheckResult = { + (left.dataType, right.dataType) match { + case (ArrayType(e1, _), ArrayType(e2, _)) if e1.sameType(e2) => + TypeCheckResult.TypeCheckSuccess + case _ => TypeCheckResult.TypeCheckFailure(s"input to function $prettyName should have " + + s"been two ${ArrayType.simpleString}s with same element type, but it's " + + s"[${left.dataType.simpleString}, ${right.dataType.simpleString}]") + } + } +} + + /** * Given an array or map, returns its size. Returns -1 if null. */ @@ -54,7 +92,7 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val childGen = child.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = false; ${childGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 : @@ -118,6 +156,158 @@ case class MapValues(child: Expression) override def prettyName: String = "map_values" } +/** + * Returns an unordered array of all entries in the given map. + */ +@ExpressionDescription( + usage = "_FUNC_(map) - Returns an unordered array of all entries in the given map.", + examples = """ + Examples: + > SELECT _FUNC_(map(1, 'a', 2, 'b')); + [(1,"a"),(2,"b")] + """, + since = "2.4.0") +case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(MapType) + + lazy val childDataType: MapType = child.dataType.asInstanceOf[MapType] + + override def dataType: DataType = { + ArrayType( + StructType( + StructField("key", childDataType.keyType, false) :: + StructField("value", childDataType.valueType, childDataType.valueContainsNull) :: + Nil), + false) + } + + override protected def nullSafeEval(input: Any): Any = { + val childMap = input.asInstanceOf[MapData] + val keys = childMap.keyArray() + val values = childMap.valueArray() + val length = childMap.numElements() + val resultData = new Array[AnyRef](length) + var i = 0; + while (i < length) { + val key = keys.get(i, childDataType.keyType) + val value = values.get(i, childDataType.valueType) + val row = new GenericInternalRow(Array[Any](key, value)) + resultData.update(i, row) + i += 1 + } + new GenericArrayData(resultData) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => { + val numElements = ctx.freshName("numElements") + val keys = ctx.freshName("keys") + val values = ctx.freshName("values") + val isKeyPrimitive = CodeGenerator.isPrimitiveType(childDataType.keyType) + val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType) + val code = if (isKeyPrimitive && isValuePrimitive) { + genCodeForPrimitiveElements(ctx, keys, values, ev.value, numElements) + } else { + genCodeForAnyElements(ctx, keys, values, ev.value, numElements) + } + s""" + |final int $numElements = $c.numElements(); + |final ArrayData $keys = $c.keyArray(); + |final ArrayData $values = $c.valueArray(); + |$code + """.stripMargin + }) + } + + private def getKey(varName: String) = CodeGenerator.getValue(varName, childDataType.keyType, "z") + + private def getValue(varName: String) = { + CodeGenerator.getValue(varName, childDataType.valueType, "z") + } + + private def genCodeForPrimitiveElements( + ctx: CodegenContext, + keys: String, + values: String, + arrayData: String, + numElements: String): String = { + val unsafeRow = ctx.freshName("unsafeRow") + val unsafeArrayData = ctx.freshName("unsafeArrayData") + val structsOffset = ctx.freshName("structsOffset") + val calculateHeader = "UnsafeArrayData.calculateHeaderPortionInBytes" + + val baseOffset = Platform.BYTE_ARRAY_OFFSET + val wordSize = UnsafeRow.WORD_SIZE + val structSize = UnsafeRow.calculateBitSetWidthInBytes(2) + wordSize * 2 + val structSizeAsLong = structSize + "L" + val keyTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType) + val valueTypeName = CodeGenerator.primitiveTypeName(childDataType.keyType) + + val valueAssignment = s"$unsafeRow.set$valueTypeName(1, ${getValue(values)});" + val valueAssignmentChecked = if (childDataType.valueContainsNull) { + s""" + |if ($values.isNullAt(z)) { + | $unsafeRow.setNullAt(1); + |} else { + | $valueAssignment + |} + """.stripMargin + } else { + valueAssignment + } + + val assignmentLoop = (byteArray: String) => + s""" + |final int $structsOffset = $calculateHeader($numElements) + $numElements * $wordSize; + |UnsafeRow $unsafeRow = new UnsafeRow(2); + |for (int z = 0; z < $numElements; z++) { + | long offset = $structsOffset + z * $structSizeAsLong; + | $unsafeArrayData.setLong(z, (offset << 32) + $structSizeAsLong); + | $unsafeRow.pointTo($byteArray, $baseOffset + offset, $structSize); + | $unsafeRow.set$keyTypeName(0, ${getKey(keys)}); + | $valueAssignmentChecked + |} + |$arrayData = $unsafeArrayData; + """.stripMargin + + ctx.createUnsafeArrayWithFallback( + unsafeArrayData, + numElements, + structSize + wordSize, + assignmentLoop, + genCodeForAnyElements(ctx, keys, values, arrayData, numElements)) + } + + private def genCodeForAnyElements( + ctx: CodegenContext, + keys: String, + values: String, + arrayData: String, + numElements: String): String = { + val genericArrayClass = classOf[GenericArrayData].getName + val rowClass = classOf[GenericInternalRow].getName + val data = ctx.freshName("internalRowArray") + + val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType) + val getValueWithCheck = if (childDataType.valueContainsNull && isValuePrimitive) { + s"$values.isNullAt(z) ? null : (Object)${getValue(values)}" + } else { + getValue(values) + } + + s""" + |final Object[] $data = new Object[$numElements]; + |for (int z = 0; z < $numElements; z++) { + | $data[z] = new $rowClass(new Object[]{${getKey(keys)}, $getValueWithCheck}); + |} + |$arrayData = new $genericArrayClass($data); + """.stripMargin + } + + override def prettyName: String = "map_entries" +} + /** * Common base class for [[SortArray]] and [[ArraySort]]. */ @@ -468,6 +658,9 @@ case class ArrayContains(left: Expression, right: Expression) override def dataType: DataType = BooleanType + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(right.dataType) + override def inputTypes: Seq[AbstractDataType] = right.dataType match { case NullType => Seq.empty case _ => left.dataType match { @@ -484,7 +677,7 @@ case class ArrayContains(left: Expression, right: Expression) TypeCheckResult.TypeCheckFailure( "Arguments must be an array followed by a value of same type as the array members") } else { - TypeCheckResult.TypeCheckSuccess + TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName") } } @@ -497,7 +690,7 @@ case class ArrayContains(left: Expression, right: Expression) arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => if (v == null) { hasNull = true - } else if (v == value) { + } else if (ordering.equiv(v, value)) { return true } ) @@ -529,6 +722,231 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } +/** + * Checks if the two arrays contain at least one common element. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(a1, a2) - Returns true if a1 contains at least a non-null element present also in a2. If the arrays have no common element and they are both non-empty and either of them contains a null element null is returned, false otherwise.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), array(3, 4, 5)); + true + """, since = "2.4.0") +// scalastyle:off line.size.limit +case class ArraysOverlap(left: Expression, right: Expression) + extends BinaryArrayExpressionWithImplicitCast { + + override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => + TypeUtils.checkForOrderingExpr(elementType, s"function $prettyName") + case failure => failure + } + + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(elementType) + + @transient private lazy val elementTypeSupportEquals = elementType match { + case BinaryType => false + case _: AtomicType => true + case _ => false + } + + @transient private lazy val doEvaluation = if (elementTypeSupportEquals) { + fastEval _ + } else { + bruteForceEval _ + } + + override def dataType: DataType = BooleanType + + override def nullable: Boolean = { + left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull || + right.dataType.asInstanceOf[ArrayType].containsNull + } + + override def nullSafeEval(a1: Any, a2: Any): Any = { + doEvaluation(a1.asInstanceOf[ArrayData], a2.asInstanceOf[ArrayData]) + } + + /** + * A fast implementation which puts all the elements from the smaller array in a set + * and then performs a lookup on it for each element of the bigger one. + * This eval mode works only for data types which implements properly the equals method. + */ + private def fastEval(arr1: ArrayData, arr2: ArrayData): Any = { + var hasNull = false + val (bigger, smaller) = if (arr1.numElements() > arr2.numElements()) { + (arr1, arr2) + } else { + (arr2, arr1) + } + if (smaller.numElements() > 0) { + val smallestSet = new mutable.HashSet[Any] + smaller.foreach(elementType, (_, v) => + if (v == null) { + hasNull = true + } else { + smallestSet += v + }) + bigger.foreach(elementType, (_, v1) => + if (v1 == null) { + hasNull = true + } else if (smallestSet.contains(v1)) { + return true + } + ) + } + if (hasNull) { + null + } else { + false + } + } + + /** + * A slower evaluation which performs a nested loop and supports all the data types. + */ + private def bruteForceEval(arr1: ArrayData, arr2: ArrayData): Any = { + var hasNull = false + if (arr1.numElements() > 0 && arr2.numElements() > 0) { + arr1.foreach(elementType, (_, v1) => + if (v1 == null) { + hasNull = true + } else { + arr2.foreach(elementType, (_, v2) => + if (v2 == null) { + hasNull = true + } else if (ordering.equiv(v1, v2)) { + return true + } + ) + }) + } + if (hasNull) { + null + } else { + false + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, (a1, a2) => { + val smaller = ctx.freshName("smallerArray") + val bigger = ctx.freshName("biggerArray") + val comparisonCode = if (elementTypeSupportEquals) { + fastCodegen(ctx, ev, smaller, bigger) + } else { + bruteForceCodegen(ctx, ev, smaller, bigger) + } + s""" + |ArrayData $smaller; + |ArrayData $bigger; + |if ($a1.numElements() > $a2.numElements()) { + | $bigger = $a1; + | $smaller = $a2; + |} else { + | $smaller = $a1; + | $bigger = $a2; + |} + |if ($smaller.numElements() > 0) { + | $comparisonCode + |} + """.stripMargin + }) + } + + /** + * Code generation for a fast implementation which puts all the elements from the smaller array + * in a set and then performs a lookup on it for each element of the bigger one. + * It works only for data types which implements properly the equals method. + */ + private def fastCodegen(ctx: CodegenContext, ev: ExprCode, smaller: String, bigger: String): String = { + val i = ctx.freshName("i") + val getFromSmaller = CodeGenerator.getValue(smaller, elementType, i) + val getFromBigger = CodeGenerator.getValue(bigger, elementType, i) + val javaElementClass = CodeGenerator.boxedType(elementType) + val javaSet = classOf[java.util.HashSet[_]].getName + val set = ctx.freshName("set") + val addToSetFromSmallerCode = nullSafeElementCodegen( + smaller, i, s"$set.add($getFromSmaller);", s"${ev.isNull} = true;") + val elementIsInSetCode = nullSafeElementCodegen( + bigger, + i, + s""" + |if ($set.contains($getFromBigger)) { + | ${ev.isNull} = false; + | ${ev.value} = true; + | break; + |} + """.stripMargin, + s"${ev.isNull} = true;") + s""" + |$javaSet<$javaElementClass> $set = new $javaSet<$javaElementClass>(); + |for (int $i = 0; $i < $smaller.numElements(); $i ++) { + | $addToSetFromSmallerCode + |} + |for (int $i = 0; $i < $bigger.numElements(); $i ++) { + | $elementIsInSetCode + |} + """.stripMargin + } + + /** + * Code generation for a slower evaluation which performs a nested loop and supports all the data types. + */ + private def bruteForceCodegen(ctx: CodegenContext, ev: ExprCode, smaller: String, bigger: String): String = { + val i = ctx.freshName("i") + val j = ctx.freshName("j") + val getFromSmaller = CodeGenerator.getValue(smaller, elementType, j) + val getFromBigger = CodeGenerator.getValue(bigger, elementType, i) + val compareValues = nullSafeElementCodegen( + smaller, + j, + s""" + |if (${ctx.genEqual(elementType, getFromSmaller, getFromBigger)}) { + | ${ev.isNull} = false; + | ${ev.value} = true; + |} + """.stripMargin, + s"${ev.isNull} = true;") + val isInSmaller = nullSafeElementCodegen( + bigger, + i, + s""" + |for (int $j = 0; $j < $smaller.numElements() && !${ev.value}; $j ++) { + | $compareValues + |} + """.stripMargin, + s"${ev.isNull} = true;") + s""" + |for (int $i = 0; $i < $bigger.numElements() && !${ev.value}; $i ++) { + | $isInSmaller + |} + """.stripMargin + } + + def nullSafeElementCodegen( + arrayVar: String, + index: String, + code: String, + isNullCode: String): String = { + if (inputTypes.exists(_.asInstanceOf[ArrayType].containsNull)) { + s""" + |if ($arrayVar.isNullAt($index)) { + | $isNullCode + |} else { + | $code + |} + """.stripMargin + } else { + code + } + } + + override def prettyName: String = "arrays_overlap" +} + /** * Slices an array according to the requested start index and length */ @@ -760,14 +1178,14 @@ case class ArrayJoin( } if (nullable) { ev.copy( - s""" + code""" |boolean ${ev.isNull} = true; |UTF8String ${ev.value} = null; |$code """.stripMargin) } else { ev.copy( - s""" + code""" |UTF8String ${ev.value} = null; |$code """.stripMargin, FalseLiteral) @@ -852,11 +1270,11 @@ case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCast val childGen = child.genCode(ctx) val javaType = CodeGenerator.javaType(dataType) val i = ctx.freshName("i") - val item = ExprCode("", + val item = ExprCode(EmptyBlock, isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"), value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType)) ev.copy(code = - s""" + code""" |${childGen.code} |boolean ${ev.isNull} = true; |$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -917,11 +1335,11 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast val childGen = child.genCode(ctx) val javaType = CodeGenerator.javaType(dataType) val i = ctx.freshName("i") - val item = ExprCode("", + val item = ExprCode(EmptyBlock, isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"), value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType)) ev.copy(code = - s""" + code""" |${childGen.code} |boolean ${ev.isNull} = true; |$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -973,13 +1391,24 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast case class ArrayPosition(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(right.dataType) + override def dataType: DataType = LongType override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, left.dataType.asInstanceOf[ArrayType].elementType) + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case f: TypeCheckResult.TypeCheckFailure => f + case TypeCheckResult.TypeCheckSuccess => + TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName") + } + } + override def nullSafeEval(arr: Any, value: Any): Any = { arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => - if (v == value) { + if (v != null && ordering.equiv(v, value)) { return (i + 1).toLong } ) @@ -1028,6 +1457,9 @@ case class ArrayPosition(left: Expression, right: Expression) since = "2.4.0") case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil { + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(left.dataType.asInstanceOf[MapType].keyType) + override def dataType: DataType = left.dataType match { case ArrayType(elementType, _) => elementType case MapType(_, valueType, _) => valueType @@ -1038,10 +1470,21 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti left.dataType match { case _: ArrayType => IntegerType case _: MapType => left.dataType.asInstanceOf[MapType].keyType + case _ => AnyDataType // no match for a wrong 'left' expression type } ) } + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case f: TypeCheckResult.TypeCheckFailure => f + case TypeCheckResult.TypeCheckSuccess if left.dataType.isInstanceOf[MapType] => + TypeUtils.checkForOrderingExpr( + left.dataType.asInstanceOf[MapType].keyType, s"function $prettyName") + case TypeCheckResult.TypeCheckSuccess => TypeCheckResult.TypeCheckSuccess + } + } + override def nullable: Boolean = true override def nullSafeEval(value: Any, ordinal: Any): Any = { @@ -1066,7 +1509,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti } } case _: MapType => - getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType) + getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType, ordering) } } @@ -1212,7 +1655,7 @@ case class Concat(children: Seq[Expression]) extends Expression { expressions = inputs, funcName = "valueConcat", extraArguments = (s"$javaType[]", args) :: Nil) - ev.copy(s""" + ev.copy(code""" $initCode $codes $javaType ${ev.value} = $concatenator.concat($args); @@ -1468,3 +1911,152 @@ case class Flatten(child: Expression) extends UnaryExpression { override def prettyName: String = "flatten" } + +/** + * Returns the array containing the given input value (left) count (right) times. + */ +@ExpressionDescription( + usage = "_FUNC_(element, count) - Returns the array containing element count times.", + examples = """ + Examples: + > SELECT _FUNC_('123', 2); + ['123', '123'] + """, + since = "2.4.0") +case class ArrayRepeat(left: Expression, right: Expression) + extends BinaryExpression with ExpectsInputTypes { + + private val MAX_ARRAY_LENGTH = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + + override def dataType: ArrayType = ArrayType(left.dataType, left.nullable) + + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegerType) + + override def nullable: Boolean = right.nullable + + override def eval(input: InternalRow): Any = { + val count = right.eval(input) + if (count == null) { + null + } else { + if (count.asInstanceOf[Int] > MAX_ARRAY_LENGTH) { + throw new RuntimeException(s"Unsuccessful try to create array with $count elements " + + s"due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + } + val element = left.eval(input) + new GenericArrayData(Array.fill(count.asInstanceOf[Int])(element)) + } + } + + override def prettyName: String = "array_repeat" + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val leftGen = left.genCode(ctx) + val rightGen = right.genCode(ctx) + val element = leftGen.value + val count = rightGen.value + val et = dataType.elementType + + val coreLogic = if (CodeGenerator.isPrimitiveType(et)) { + genCodeForPrimitiveElement(ctx, et, element, count, leftGen.isNull, ev.value) + } else { + genCodeForNonPrimitiveElement(ctx, element, count, leftGen.isNull, ev.value) + } + val resultCode = nullElementsProtection(ev, rightGen.isNull, coreLogic) + + ev.copy(code = + code""" + |boolean ${ev.isNull} = false; + |${leftGen.code} + |${rightGen.code} + |${CodeGenerator.javaType(dataType)} ${ev.value} = + | ${CodeGenerator.defaultValue(dataType)}; + |$resultCode + """.stripMargin) + } + + private def nullElementsProtection( + ev: ExprCode, + rightIsNull: String, + coreLogic: String): String = { + if (nullable) { + s""" + |if ($rightIsNull) { + | ${ev.isNull} = true; + |} else { + | ${coreLogic} + |} + """.stripMargin + } else { + coreLogic + } + } + + private def genCodeForNumberOfElements(ctx: CodegenContext, count: String): (String, String) = { + val numElements = ctx.freshName("numElements") + val numElementsCode = + s""" + |int $numElements = 0; + |if ($count > 0) { + | $numElements = $count; + |} + |if ($numElements > $MAX_ARRAY_LENGTH) { + | throw new RuntimeException("Unsuccessful try to create array with " + $numElements + + | " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH."); + |} + """.stripMargin + + (numElements, numElementsCode) + } + + private def genCodeForPrimitiveElement( + ctx: CodegenContext, + elementType: DataType, + element: String, + count: String, + leftIsNull: String, + arrayDataName: String): String = { + val tempArrayDataName = ctx.freshName("tempArrayData") + val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) + val errorMessage = s" $prettyName failed." + val (numElemName, numElemCode) = genCodeForNumberOfElements(ctx, count) + + s""" + |$numElemCode + |${ctx.createUnsafeArray(tempArrayDataName, numElemName, elementType, errorMessage)} + |if (!$leftIsNull) { + | for (int k = 0; k < $tempArrayDataName.numElements(); k++) { + | $tempArrayDataName.set$primitiveValueTypeName(k, $element); + | } + |} else { + | for (int k = 0; k < $tempArrayDataName.numElements(); k++) { + | $tempArrayDataName.setNullAt(k); + | } + |} + |$arrayDataName = $tempArrayDataName; + """.stripMargin + } + + private def genCodeForNonPrimitiveElement( + ctx: CodegenContext, + element: String, + count: String, + leftIsNull: String, + arrayDataName: String): String = { + val genericArrayClass = classOf[GenericArrayData].getName + val arrayName = ctx.freshName("arrayObject") + val (numElemName, numElemCode) = genCodeForNumberOfElements(ctx, count) + + s""" + |$numElemCode + |Object[] $arrayName = new Object[(int)$numElemName]; + |if (!$leftIsNull) { + | for (int k = 0; k < $numElemName; k++) { + | $arrayName[k] = $element; + | } + |} + |$arrayDataName = new $genericArrayClass($arrayName); + """.stripMargin + } + +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 67876a8565488..a9867aaeb0cfe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform @@ -63,7 +64,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { val (preprocess, assigns, postprocess, arrayData) = GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false) ev.copy( - code = preprocess + assigns + postprocess, + code = code"${preprocess}${assigns}${postprocess}", value = JavaCode.variable(arrayData, dataType), isNull = FalseLiteral) } @@ -219,7 +220,7 @@ case class CreateMap(children: Seq[Expression]) extends Expression { val (preprocessValueData, assignValues, postprocessValueData, valueArrayData) = GenArrayData.genCodeToCreateArrayData(ctx, valueDt, evalValues, false) val code = - s""" + code""" final boolean ${ev.isNull} = false; $preprocessKeyData $assignKeys @@ -373,7 +374,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc extraArguments = "Object[]" -> values :: Nil) ev.copy(code = - s""" + code""" |Object[] $values = new Object[${valExprs.size}]; |$valuesCode |final InternalRow ${ev.value} = new $rowClass($values); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 3fba52d745453..99671d5b863c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} -import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.types._ //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -273,7 +273,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression) abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes { // todo: current search is O(n), improve it. - def getValueEval(value: Any, ordinal: Any, keyType: DataType): Any = { + def getValueEval(value: Any, ordinal: Any, keyType: DataType, ordering: Ordering[Any]): Any = { val map = value.asInstanceOf[MapData] val length = map.numElements() val keys = map.keyArray() @@ -282,7 +282,7 @@ abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTy var i = 0 var found = false while (i < length && !found) { - if (keys.get(i, keyType) == ordinal) { + if (ordering.equiv(keys.get(i, keyType), ordinal)) { found = true } else { i += 1 @@ -345,8 +345,19 @@ abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTy case class GetMapValue(child: Expression, key: Expression) extends GetMapValueUtil with ExtractValue with NullIntolerant { + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(keyType) + private def keyType = child.dataType.asInstanceOf[MapType].keyType + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case f: TypeCheckResult.TypeCheckFailure => f + case TypeCheckResult.TypeCheckSuccess => + TypeUtils.checkForOrderingExpr(keyType, s"function $prettyName") + } + } + // We have done type checking for child in `ExtractValue`, so only need to check the `key`. override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType) @@ -363,7 +374,7 @@ case class GetMapValue(child: Expression, key: Expression) // todo: current search is O(n), improve it. override def nullSafeEval(value: Any, ordinal: Any): Any = { - getValueEval(value, ordinal, keyType) + getValueEval(value, ordinal, keyType, ordering) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 205d77f6a9acf..77ac6c088022e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ // scalastyle:off line.size.limit @@ -66,7 +67,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi val falseEval = falseValue.genCode(ctx) val code = - s""" + code""" |${condEval.code} |boolean ${ev.isNull} = false; |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -265,7 +266,7 @@ case class CaseWhen( }.mkString) ev.copy(code = - s""" + code""" |${CodeGenerator.JAVA_BYTE} $resultState = $NOT_MATCHED; |do { | $codes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 76aa61415a11f..e8d85f72f7a7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -27,6 +27,7 @@ import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -717,7 +718,7 @@ abstract class UnixTime } else { val formatterName = ctx.addReferenceObj("formatter", formatter, df) val eval1 = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -746,7 +747,7 @@ abstract class UnixTime }) case TimestampType => val eval1 = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -757,7 +758,7 @@ abstract class UnixTime val tz = ctx.addReferenceObj("timeZone", timeZone) val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") val eval1 = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -852,7 +853,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ } else { val formatterName = ctx.addReferenceObj("formatter", formatter, df) val t = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${t.code} boolean ${ev.isNull} = ${t.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -1042,7 +1043,7 @@ case class StringToTimestampWithoutTimezone(child: Expression, timeZoneId: Optio val tz = ctx.addReferenceObj("timeZone", timeZone) val longOpt = ctx.freshName("longOpt") val eval = child.genCode(ctx) - val code = s""" + val code = code""" |${eval.code} |${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = true; |${CodeGenerator.JAVA_LONG} ${ev.value} = ${CodeGenerator.defaultValue(TimestampType)}; @@ -1090,7 +1091,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression) if (right.foldable) { val tz = right.eval().asInstanceOf[UTF8String] if (tz == null) { - ev.copy(code = s""" + ev.copy(code = code""" |boolean ${ev.isNull} = true; |long ${ev.value} = 0; """.stripMargin) @@ -1104,7 +1105,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression) ctx.addImmutableStateIfNotExists(tzClass, utcTerm, v => s"""$v = $dtu.getTimeZone("UTC");""") val eval = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" |${eval.code} |boolean ${ev.isNull} = ${eval.isNull}; |long ${ev.value} = 0; @@ -1194,13 +1195,21 @@ case class AddMonths(startDate: Expression, numMonths: Expression) } /** - * Returns number of months between dates date1 and date2. + * Returns number of months between times `timestamp1` and `timestamp2`. + * If `timestamp1` is later than `timestamp2`, then the result is positive. + * If `timestamp1` and `timestamp2` are on the same day of month, or both + * are the last day of month, time of day will be ignored. Otherwise, the + * difference is calculated based on 31 days per month, and rounded to + * 8 digits unless roundOff=false. */ // scalastyle:off line.size.limit @ExpressionDescription( usage = """ - _FUNC_(timestamp1, timestamp2[, roundOff]) - Returns number of months between `timestamp1` and `timestamp2`. - The result is rounded to 8 decimal places by default. Set roundOff=false otherwise."""", + _FUNC_(timestamp1, timestamp2[, roundOff]) - If `timestamp1` is later than `timestamp2`, then the result + is positive. If `timestamp1` and `timestamp2` are on the same day of month, or both + are the last day of month, time of day will be ignored. Otherwise, the difference is + calculated based on 31 days per month, and rounded to 8 digits unless roundOff=false. + """, examples = """ Examples: > SELECT _FUNC_('1997-02-28 10:30:00', '1996-10-30'); @@ -1279,7 +1288,7 @@ case class ToUTCTimestamp(left: Expression, right: Expression) if (right.foldable) { val tz = right.eval().asInstanceOf[UTF8String] if (tz == null) { - ev.copy(code = s""" + ev.copy(code = code""" |boolean ${ev.isNull} = true; |long ${ev.value} = 0; """.stripMargin) @@ -1293,7 +1302,7 @@ case class ToUTCTimestamp(left: Expression, right: Expression) ctx.addImmutableStateIfNotExists(tzClass, utcTerm, v => s"""$v = $dtu.getTimeZone("UTC");""") val eval = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" |${eval.code} |boolean ${ev.isNull} = ${eval.isNull}; |long ${ev.value} = 0; @@ -1436,13 +1445,13 @@ trait TruncInstant extends BinaryExpression with ImplicitCastInputTypes { val javaType = CodeGenerator.javaType(dataType) if (format.foldable) { if (truncLevel == DateTimeUtils.TRUNC_INVALID || truncLevel > maxLevel) { - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""") } else { val t = instant.genCode(ctx) val truncFuncStr = truncFunc(t.value, truncLevel.toString) - ev.copy(code = s""" + ev.copy(code = code""" ${t.code} boolean ${ev.isNull} = ${t.isNull}; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index db1579ba28671..04de83343be71 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, EmptyBlock, ExprCode} import org.apache.spark.sql.types._ /** @@ -72,7 +72,8 @@ case class PromotePrecision(child: Expression) extends UnaryExpression { override def eval(input: InternalRow): Any = child.eval(input) /** Just a simple pass-through for code generation. */ override def genCode(ctx: CodegenContext): ExprCode = child.genCode(ctx) - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev.copy("") + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + ev.copy(EmptyBlock) override def prettyName: String = "promote_precision" override def sql: String = child.sql override lazy val canonicalized: Expression = child.canonicalized diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 3af4bfebad45e..b7c52f1d7b40a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ @@ -215,7 +216,7 @@ case class Stack(children: Seq[Expression]) extends Generator { // Create the collection. val wrapperClass = classOf[mutable.WrappedArray[_]].getName ev.copy(code = - s""" + code""" |$code |$wrapperClass ${ev.value} = $wrapperClass$$.MODULE$$.make($rowData); """.stripMargin, isNull = FalseLiteral) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index ef790338bdd27..cec00b66f873c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -28,6 +28,7 @@ import org.apache.commons.codec.digest.DigestUtils import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform @@ -293,7 +294,7 @@ abstract class HashExpression[E] extends Expression { foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) ev.copy(code = - s""" + code""" |$hashResultType ${ev.value} = $seed; |$codes """.stripMargin) @@ -674,7 +675,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { ev.copy(code = - s""" + code""" |${CodeGenerator.JAVA_INT} ${ev.value} = $seed; |${CodeGenerator.JAVA_INT} $childHash = 0; |$codes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala index 2a3cc580273ee..3b0141ad52cc7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.rdd.InputFileBlockHolder import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.{DataType, LongType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -42,8 +43,9 @@ case class InputFileName() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + - s"$className.getInputFilePath();", isNull = FalseLiteral) + val typeDef = s"final ${CodeGenerator.javaType(dataType)}" + ev.copy(code = code"$typeDef ${ev.value} = $className.getInputFilePath();", + isNull = FalseLiteral) } } @@ -65,8 +67,8 @@ case class InputFileBlockStart() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + - s"$className.getStartOffset();", isNull = FalseLiteral) + val typeDef = s"final ${CodeGenerator.javaType(dataType)}" + ev.copy(code = code"$typeDef ${ev.value} = $className.getStartOffset();", isNull = FalseLiteral) } } @@ -88,7 +90,7 @@ case class InputFileBlockLength() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + - s"$className.getLength();", isNull = FalseLiteral) + val typeDef = s"final ${CodeGenerator.javaType(dataType)}" + ev.copy(code = code"$typeDef ${ev.value} = $className.getLength();", isNull = FalseLiteral) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 34161f0f03f4a..04a4eb0ffc032 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -548,7 +548,7 @@ case class JsonToStructs( forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)) override def checkInputDataTypes(): TypeCheckResult = nullableSchema match { - case _: StructType | ArrayType(_: StructType, _) => + case _: StructType | ArrayType(_: StructType, _) | _: MapType => super.checkInputDataTypes() case _ => TypeCheckResult.TypeCheckFailure( s"Input schema ${nullableSchema.simpleString} must be a struct or an array of structs.") @@ -558,6 +558,7 @@ case class JsonToStructs( lazy val rowSchema = nullableSchema match { case st: StructType => st case ArrayType(st: StructType, _) => st + case mt: MapType => mt } // This converts parsed rows to the desired output by the given schema. @@ -567,6 +568,8 @@ case class JsonToStructs( (rows: Seq[InternalRow]) => if (rows.length == 1) rows.head else null case ArrayType(_: StructType, _) => (rows: Seq[InternalRow]) => new GenericArrayData(rows) + case _: MapType => + (rows: Seq[InternalRow]) => rows.head.getMap(0) } @transient @@ -613,6 +616,11 @@ case class JsonToStructs( } override def inputTypes: Seq[AbstractDataType] = StringType :: Nil + + override def sql: String = schema match { + case _: MapType => "entries" + case _ => super.sql + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index bc4cfcec47425..c2e1720259b53 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.NumberConverter import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -1191,11 +1192,11 @@ abstract class RoundBase(child: Expression, scale: Expression, val javaType = CodeGenerator.javaType(dataType) if (scaleV == null) { // if scale is null, no need to eval its child at all - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""") } else { - ev.copy(code = s""" + ev.copy(code = code""" ${ce.code} boolean ${ev.isNull} = ${ce.isNull}; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index b7834696cafc3..5d98dac46cf17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -21,6 +21,7 @@ import java.util.UUID import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.RandomUUIDGenerator import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -88,7 +89,7 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa // Use unnamed reference that doesn't create a local field here to reduce the number of fields // because errMsgField is used only when the value is null or false. val errMsgField = ctx.addReferenceObj("errMsg", errMsg) - ExprCode(code = s"""${eval.code} + ExprCode(code = code"""${eval.code} |if (${eval.isNull} || !${eval.value}) { | throw new RuntimeException($errMsgField); |}""".stripMargin, isNull = TrueLiteral, @@ -151,7 +152,7 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Sta ctx.addPartitionInitializationStatement(s"$randomGen = " + "new org.apache.spark.sql.catalyst.util.RandomUUIDGenerator(" + s"${randomSeed.get}L + partitionIndex);") - ev.copy(code = s"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();", + ev.copy(code = code"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();", isNull = FalseLiteral) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 0787342bce6bc..2eeed3bbb2d91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -111,7 +112,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { ev.copy(code = - s""" + code""" |${ev.isNull} = true; |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |do { @@ -232,7 +233,7 @@ case class IsNaN(child: Expression) extends UnaryExpression val eval = child.genCode(ctx) child.dataType match { case DoubleType | FloatType => - ev.copy(code = s""" + ev.copy(code = code""" ${eval.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; ${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value});""", isNull = FalseLiteral) @@ -278,7 +279,7 @@ case class NaNvl(left: Expression, right: Expression) val rightGen = right.genCode(ctx) left.dataType match { case DoubleType | FloatType => - ev.copy(code = s""" + ev.copy(code = code""" ${leftGen.code} boolean ${ev.isNull} = false; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -440,7 +441,7 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate }.mkString) ev.copy(code = - s""" + code""" |${CodeGenerator.JAVA_INT} $nonnull = 0; |do { | $codes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index f974fd81fc788..2bf4203d0fec3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -269,7 +270,7 @@ case class StaticInvoke( s"${ev.value} = $callFunc;" } - val code = s""" + val code = code""" $argCode $prepareIsNull $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -385,8 +386,7 @@ case class Invoke( """ } - val code = s""" - ${obj.code} + val code = obj.code + code""" boolean ${ev.isNull} = true; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${obj.isNull}) { @@ -492,7 +492,7 @@ case class NewInstance( s"new $className($argString)" } - val code = s""" + val code = code""" $argCode ${outer.map(_.code).getOrElse("")} final $javaType ${ev.value} = ${ev.isNull} ? @@ -532,9 +532,7 @@ case class UnwrapOption( val javaType = CodeGenerator.javaType(dataType) val inputObject = child.genCode(ctx) - val code = s""" - ${inputObject.code} - + val code = inputObject.code + code""" final boolean ${ev.isNull} = ${inputObject.isNull} || ${inputObject.value}.isEmpty(); $javaType ${ev.value} = ${ev.isNull} ? ${CodeGenerator.defaultValue(dataType)} : (${CodeGenerator.boxedType(javaType)}) ${inputObject.value}.get(); @@ -564,9 +562,7 @@ case class WrapOption(child: Expression, optType: DataType) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val inputObject = child.genCode(ctx) - val code = s""" - ${inputObject.code} - + val code = inputObject.code + code""" scala.Option ${ev.value} = ${inputObject.isNull} ? scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value}); @@ -935,8 +931,7 @@ case class MapObjects private( ) } - val code = s""" - ${genInputData.code} + val code = genInputData.code + code""" ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${genInputData.isNull}) { @@ -1147,8 +1142,7 @@ case class CatalystToExternalMap private( """ val getBuilderResult = s"${ev.value} = (${collClass.getName}) $builderValue.result();" - val code = s""" - ${genInputData.code} + val code = genInputData.code + code""" ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${genInputData.isNull}) { @@ -1391,9 +1385,8 @@ case class ExternalMapToCatalyst private( val mapCls = classOf[ArrayBasedMapData].getName val convertedKeyType = CodeGenerator.boxedType(keyConverter.dataType) val convertedValueType = CodeGenerator.boxedType(valueConverter.dataType) - val code = - s""" - ${inputMap.code} + val code = inputMap.code + + code""" ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${inputMap.isNull}) { final int $length = ${inputMap.value}.size(); @@ -1471,7 +1464,7 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) val schemaField = ctx.addReferenceObj("schema", schema) val code = - s""" + code""" |Object[] $values = new Object[${children.size}]; |$childrenCode |final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, $schemaField); @@ -1499,8 +1492,7 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) val javaType = CodeGenerator.javaType(dataType) val serialize = s"$serializer.serialize(${input.value}, null).array()" - val code = s""" - ${input.code} + val code = input.code + code""" final $javaType ${ev.value} = ${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $serialize; """ @@ -1532,8 +1524,7 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B val deserialize = s"($javaType) $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null)" - val code = s""" - ${input.code} + val code = input.code + code""" final $javaType ${ev.value} = ${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $deserialize; """ @@ -1614,9 +1605,8 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp funcName = "initializeJavaBean", extraArguments = beanInstanceJavaType -> javaBeanInstance :: Nil) - val code = - s""" - |${instanceGen.code} + val code = instanceGen.code + + code""" |$beanInstanceJavaType $javaBeanInstance = ${instanceGen.value}; |if (!${instanceGen.isNull}) { | $initializeCode @@ -1664,9 +1654,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil) // because errMsgField is used only when the value is null. val errMsgField = ctx.addReferenceObj("errMsg", errMsg) - val code = s""" - ${childGen.code} - + val code = childGen.code + code""" if (${childGen.isNull}) { throw new NullPointerException($errMsgField); } @@ -1709,7 +1697,7 @@ case class GetExternalRowField( // because errMsgField is used only when the field is null. val errMsgField = ctx.addReferenceObj("errMsg", errMsg) val row = child.genCode(ctx) - val code = s""" + val code = code""" ${row.code} if (${row.isNull}) { @@ -1784,7 +1772,7 @@ case class ValidateExternalType(child: Expression, expected: DataType) s"$obj instanceof ${CodeGenerator.boxedType(dataType)}" } - val code = s""" + val code = code""" ${input.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${input.isNull}) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index f8c6dc4e6adc9..f54103c4fbfba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -22,6 +22,7 @@ import scala.collection.immutable.TreeSet import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -290,7 +291,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { }.mkString("\n")) ev.copy(code = - s""" + code""" |${valueGen.code} |byte $tmpResult = $HAS_NULL; |if (!${valueGen.isNull}) { @@ -354,7 +355,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with "" } ev.copy(code = - s""" + code""" |${childGen.code} |${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = ${childGen.isNull}; |${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = false; @@ -406,7 +407,7 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with // The result should be `false`, if any of them is `false` whenever the other is null or not. if (!left.nullable && !right.nullable) { - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.value} = false; @@ -415,7 +416,7 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with ${ev.value} = ${eval2.value}; }""", isNull = FalseLiteral) } else { - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.isNull} = false; boolean ${ev.value} = false; @@ -470,7 +471,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P // The result should be `true`, if any of them is `true` whenever the other is null or not. if (!left.nullable && !right.nullable) { ev.isNull = FalseLiteral - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.value} = true; @@ -479,7 +480,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P ${ev.value} = ${eval2.value}; }""", isNull = FalseLiteral) } else { - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.isNull} = false; boolean ${ev.value} = true; @@ -621,7 +622,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp val eval1 = left.genCode(ctx) val eval2 = right.genCode(ctx) val equalCode = ctx.genEqual(left.dataType, eval1.value, eval2.value) - ev.copy(code = eval1.code + eval2.code + s""" + ev.copy(code = eval1.code + eval2.code + code""" boolean ${ev.value} = (${eval1.isNull} && ${eval2.isNull}) || (!${eval1.isNull} && !${eval2.isNull} && $equalCode);""", isNull = FalseLiteral) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 2653b28f6c3bd..926c2f00d430d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom @@ -82,7 +83,7 @@ case class Rand(child: Expression) extends RDG { val rngTerm = ctx.addMutableState(className, "rng") ctx.addPartitionInitializationStatement( s"$rngTerm = new $className(${seed}L + partitionIndex);") - ev.copy(code = s""" + ev.copy(code = code""" final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", isNull = FalseLiteral) } @@ -120,7 +121,7 @@ case class Randn(child: Expression) extends RDG { val rngTerm = ctx.addMutableState(className, "rng") ctx.addPartitionInitializationStatement( s"$rngTerm = new $className(${seed}L + partitionIndex);") - ev.copy(code = s""" + ev.copy(code = code""" final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = FalseLiteral) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index ad0c0791d895f..7b68bb771faf3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -23,6 +23,7 @@ import java.util.regex.{MatchResult, Pattern} import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -123,7 +124,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. val eval = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${eval.code} boolean ${ev.isNull} = ${eval.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -132,7 +133,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi } """) } else { - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; """) @@ -198,7 +199,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. val eval = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${eval.code} boolean ${ev.isNull} = ${eval.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -207,7 +208,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress } """) } else { - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; """) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index ea005a26a4c8b..9823b2fc5ad97 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -27,6 +27,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{ByteArray, UTF8String} @@ -105,7 +106,7 @@ case class ConcatWs(children: Seq[Expression]) expressions = inputs, funcName = "valueConcatWs", extraArguments = ("UTF8String[]", args) :: Nil) - ev.copy(s""" + ev.copy(code""" UTF8String[] $args = new UTF8String[$numArgs]; ${separator.code} $codes @@ -149,7 +150,7 @@ case class ConcatWs(children: Seq[Expression]) } }.unzip - val codes = ctx.splitExpressionsWithCurrentInputs(evals.map(_.code)) + val codes = ctx.splitExpressionsWithCurrentInputs(evals.map(_.code.toString)) val varargCounts = ctx.splitExpressionsWithCurrentInputs( expressions = varargCount, @@ -176,7 +177,7 @@ case class ConcatWs(children: Seq[Expression]) foldFunctions = _.map(funcCall => s"$idxVararg = $funcCall;").mkString("\n")) ev.copy( - s""" + code""" $codes int $varargNum = ${children.count(_.dataType == StringType) - 1}; int $idxVararg = 0; @@ -288,7 +289,7 @@ case class Elt(children: Seq[Expression]) extends Expression { }.mkString) ev.copy( - s""" + code""" |${index.code} |final int $indexVal = ${index.value}; |${CodeGenerator.JAVA_BOOLEAN} $indexMatched = false; @@ -654,7 +655,7 @@ case class StringTrim( val srcString = evals(0) if (evals.length == 1) { - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -671,7 +672,7 @@ case class StringTrim( } else { ${ev.value} = ${srcString.value}.trim(${trimString.value}); }""" - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -754,7 +755,7 @@ case class StringTrimLeft( val srcString = evals(0) if (evals.length == 1) { - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -771,7 +772,7 @@ case class StringTrimLeft( } else { ${ev.value} = ${srcString.value}.trimLeft(${trimString.value}); }""" - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -856,7 +857,7 @@ case class StringTrimRight( val srcString = evals(0) if (evals.length == 1) { - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -873,7 +874,7 @@ case class StringTrimRight( } else { ${ev.value} = ${srcString.value}.trimRight(${trimString.value}); }""" - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -1024,7 +1025,7 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) val substrGen = substr.genCode(ctx) val strGen = str.genCode(ctx) val startGen = start.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" int ${ev.value} = 0; boolean ${ev.isNull} = false; ${startGen.code} @@ -1350,7 +1351,7 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC val formatter = classOf[java.util.Formatter].getName val sb = ctx.freshName("sb") val stringBuffer = classOf[StringBuffer].getName - ev.copy(code = s""" + ev.copy(code = code""" ${pattern.code} boolean ${ev.isNull} = ${pattern.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index a5a4a13eb608b..c3a4ca8f64bf6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -36,7 +36,7 @@ import org.apache.spark.util.Utils * Constructs a parser for a given schema that translates a json string to an [[InternalRow]]. */ class JacksonParser( - schema: StructType, + schema: DataType, val options: JSONOptions) extends Logging { import JacksonUtils._ @@ -57,7 +57,14 @@ class JacksonParser( * to a value according to a desired schema. This is a wrapper for the method * `makeConverter()` to handle a row wrapped with an array. */ - private def makeRootConverter(st: StructType): JsonParser => Seq[InternalRow] = { + private def makeRootConverter(dt: DataType): JsonParser => Seq[InternalRow] = { + dt match { + case st: StructType => makeStructRootConverter(st) + case mt: MapType => makeMapRootConverter(mt) + } + } + + private def makeStructRootConverter(st: StructType): JsonParser => Seq[InternalRow] = { val elementConverter = makeConverter(st) val fieldConverters = st.map(_.dataType).map(makeConverter).toArray (parser: JsonParser) => parseJsonToken[Seq[InternalRow]](parser, st) { @@ -87,6 +94,13 @@ class JacksonParser( } } + private def makeMapRootConverter(mt: MapType): JsonParser => Seq[InternalRow] = { + val fieldConverter = makeConverter(mt.valueType) + (parser: JsonParser) => parseJsonToken[Seq[InternalRow]](parser, mt) { + case START_OBJECT => Seq(InternalRow(convertMap(parser, fieldConverter))) + } + } + /** * Create a converter which converts the JSON documents held by the `JsonParser` * to a value according to a desired schema. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 64eed23884584..b9ece295c2510 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -504,6 +504,9 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging withJoinRelations(join, relation) } if (ctx.pivotClause() != null) { + if (!ctx.lateralView.isEmpty) { + throw new ParseException("LATERAL cannot be used together with PIVOT in FROM clause", ctx) + } withPivot(ctx.pivotClause, from) } else { ctx.lateralView.asScala.foldLeft(from)(withGenerate) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index e646da0659e85..80f15053005ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -885,13 +885,13 @@ object DateTimeUtils { /** * Returns number of months between time1 and time2. time1 and time2 are expressed in - * microseconds since 1.1.1970. + * microseconds since 1.1.1970. If time1 is later than time2, the result is positive. * - * If time1 and time2 having the same day of month, or both are the last day of month, - * it returns an integer (time under a day will be ignored). + * If time1 and time2 are on the same day of month, or both are the last day of month, + * returns, time of day will be ignored. * * Otherwise, the difference is calculated based on 31 days per month. - * If `roundOff` is set to true, the result is rounded to 8 decimal places. + * The result is rounded to 8 decimal places if `roundOff` is set to true. */ def monthsBetween( time1: SQLTimestamp, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala new file mode 100644 index 0000000000000..19f67236c8979 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import java.util.{Map => JMap} + +import org.apache.spark.{TaskContext, TaskContextImpl} +import org.apache.spark.internal.config.{ConfigEntry, ConfigProvider, ConfigReader} + +/** + * A readonly SQLConf that will be created by tasks running at the executor side. It reads the + * configs from the local properties which are propagated from driver to executors. + */ +class ReadOnlySQLConf(context: TaskContext) extends SQLConf { + + @transient override val settings: JMap[String, String] = { + context.asInstanceOf[TaskContextImpl].getLocalProperties().asInstanceOf[JMap[String, String]] + } + + @transient override protected val reader: ConfigReader = { + new ConfigReader(new TaskContextConfigProvider(context)) + } + + override protected def setConfWithCheck(key: String, value: String): Unit = { + throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") + } + + override def unsetConf(key: String): Unit = { + throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") + } + + override def unsetConf(entry: ConfigEntry[_]): Unit = { + throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") + } + + override def clear(): Unit = { + throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") + } + + override def clone(): SQLConf = { + throw new UnsupportedOperationException("Cannot clone/copy ReadOnlySQLConf.") + } + + override def copy(entries: (ConfigEntry[_], Any)*): SQLConf = { + throw new UnsupportedOperationException("Cannot clone/copy ReadOnlySQLConf.") + } +} + +class TaskContextConfigProvider(context: TaskContext) extends ConfigProvider { + override def get(key: String): Option[String] = Option(context.getLocalProperty(key)) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 0b1965c438e27..d0478d6ad250b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -27,7 +27,7 @@ import scala.util.matching.Regex import org.apache.hadoop.fs.Path -import org.apache.spark.TaskContext +import org.apache.spark.{SparkContext, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit @@ -95,7 +95,9 @@ object SQLConf { /** * Returns the active config object within the current scope. If there is an active SparkSession, - * the proper SQLConf associated with the thread's session is used. + * the proper SQLConf associated with the thread's active session is used. If it's called from + * tasks in the executor side, a SQLConf will be created from job local properties, which are set + * and propagated from the driver side. * * The way this works is a little bit convoluted, due to the fact that config was added initially * only for physical plans (and as a result not in sql/catalyst module). @@ -108,11 +110,20 @@ object SQLConf { * run unit tests (that does not involve SparkSession) in serial order. */ def get: SQLConf = { - if (Utils.isTesting && TaskContext.get != null) { - // we're accessing it during task execution, fail. - throw new IllegalStateException("SQLConf should only be created and accessed on the driver.") + if (TaskContext.get != null) { + new ReadOnlySQLConf(TaskContext.get()) + } else { + if (Utils.isTesting && SparkContext.getActive.isDefined) { + // DAGScheduler event loop thread does not have an active SparkSession, the `confGetter` + // will return `fallbackConf` which is unexpected. Here we prevent it from happening. + val schedulerEventLoopThread = + SparkContext.getActive.get.dagScheduler.eventProcessLoop.eventThread + if (schedulerEventLoopThread.getId == Thread.currentThread().getId) { + throw new RuntimeException("Cannot get SQLConf inside scheduler event loop thread.") + } + } + confGetter.get()() } - confGetter.get()() } val OPTIMIZER_MAX_ITERATIONS = buildConf("spark.sql.optimizer.maxIterations") @@ -1161,8 +1172,17 @@ object SQLConf { .booleanConf .createWithDefault(true) + val SQL_OPTIONS_REDACTION_PATTERN = + buildConf("spark.sql.redaction.options.regex") + .doc("Regex to decide which keys in a Spark SQL command's options map contain sensitive " + + "information. The values of options whose names that match this regex will be redacted " + + "in the explain output. This redaction is applied on top of the global redaction " + + s"configuration defined by ${SECRET_REDACTION_PATTERN.key}.") + .regexConf + .createWithDefault("(?i)url".r) + val SQL_STRING_REDACTION_PATTERN = - ConfigBuilder("spark.sql.redaction.string.regex") + buildConf("spark.sql.redaction.string.regex") .doc("Regex to decide which parts of strings produced by Spark contain sensitive " + "information. When this regex matches a string part, that string part is replaced by a " + "dummy value. This is currently used to redact the output of SQL explain commands. " + @@ -1259,6 +1279,15 @@ object SQLConf { .booleanConf .createWithDefault(true) + val TOP_K_SORT_FALLBACK_THRESHOLD = + buildConf("spark.sql.execution.topKSortFallbackThreshold") + .internal() + .doc("In SQL queries with a SORT followed by a LIMIT like " + + "'SELECT x FROM t ORDER BY y LIMIT m', if m is under this threshold, do a top-K sort" + + " in memory, otherwise do a global sort which spills to disk if necessary.") + .intConf + .createWithDefault(Int.MaxValue) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -1266,6 +1295,13 @@ object SQLConf { object Replaced { val MAPREDUCE_JOB_REDUCES = "mapreduce.job.reduces" } + + val CSV_PARSER_COLUMN_PRUNING = buildConf("spark.sql.csv.parser.columnPruning.enabled") + .internal() + .doc("If it is set to true, column names of the requested schema are passed to CSV parser. " + + "Other column values can be ignored during parsing even if they are malformed.") + .booleanConf + .createWithDefault(true) } /** @@ -1284,7 +1320,7 @@ class SQLConf extends Serializable with Logging { @transient protected[spark] val settings = java.util.Collections.synchronizedMap( new java.util.HashMap[String, String]()) - @transient private val reader = new ConfigReader(settings) + @transient protected val reader = new ConfigReader(settings) /** ************************ Spark SQL Params/Hints ******************* */ @@ -1420,10 +1456,12 @@ class SQLConf extends Serializable with Logging { def fileCompressionFactor: Double = getConf(FILE_COMRESSION_FACTOR) - def stringRedationPattern: Option[Regex] = SQL_STRING_REDACTION_PATTERN.readFrom(reader) + def stringRedactionPattern: Option[Regex] = getConf(SQL_STRING_REDACTION_PATTERN) def sortBeforeRepartition: Boolean = getConf(SORT_BEFORE_REPARTITION) + def topKSortFallbackThreshold: Int = getConf(TOP_K_SORT_FALLBACK_THRESHOLD) + /** * Returns the [[Resolver]] for the current configuration, which can be used to determine if two * identifiers are equal. @@ -1727,6 +1765,17 @@ class SQLConf extends Serializable with Logging { }.toSeq } + /** + * Redacts the given option map according to the description of SQL_OPTIONS_REDACTION_PATTERN. + */ + def redactOptions(options: Map[String, String]): Map[String, String] = { + val regexes = Seq( + getConf(SQL_OPTIONS_REDACTION_PATTERN), + SECRET_REDACTION_PATTERN.readFrom(reader)) + + regexes.foldLeft(options.toSeq) { case (opts, r) => Utils.redact(Some(r), opts) }.toMap + } + /** * Return whether a given key is set in this [[SQLConf]]. */ @@ -1734,7 +1783,7 @@ class SQLConf extends Serializable with Logging { settings.containsKey(key) } - private def setConfWithCheck(key: String, value: String): Unit = { + protected def setConfWithCheck(key: String, value: String): Unit = { settings.put(key, value) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 4ee12db9c10ca..0bef11659fc9e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -81,7 +81,11 @@ abstract class DataType extends AbstractDataType { * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). */ private[spark] def sameType(other: DataType): Boolean = - DataType.equalsIgnoreNullability(this, other) + if (SQLConf.get.caseSensitiveAnalysis) { + DataType.equalsIgnoreNullability(this, other) + } else { + DataType.equalsIgnoreCaseAndNullability(this, other) + } /** * Returns the same data type but set all nullability fields are true @@ -214,7 +218,7 @@ object DataType { /** * Compares two types, ignoring nullability of ArrayType, MapType, StructType. */ - private[sql] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = { + private[types] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = { (left, right) match { case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) => equalsIgnoreNullability(leftElementType, rightElementType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index f73e045685ee1..0acd3b490447d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -128,17 +128,17 @@ class TypeCoercionSuite extends AnalysisTest { } private def checkWidenType( - widenFunc: (DataType, DataType, Boolean) => Option[DataType], + widenFunc: (DataType, DataType) => Option[DataType], t1: DataType, t2: DataType, expected: Option[DataType], isSymmetric: Boolean = true): Unit = { - var found = widenFunc(t1, t2, conf.caseSensitiveAnalysis) + var found = widenFunc(t1, t2) assert(found == expected, s"Expected $expected as wider common type for $t1 and $t2, found $found") // Test both directions to make sure the widening is symmetric. if (isSymmetric) { - found = widenFunc(t2, t1, conf.caseSensitiveAnalysis) + found = widenFunc(t2, t1) assert(found == expected, s"Expected $expected as wider common type for $t2 and $t1, found $found") } @@ -524,11 +524,11 @@ class TypeCoercionSuite extends AnalysisTest { test("cast NullType for expressions that implement ExpectsInputTypes") { import TypeCoercionSuite._ - ruleTest(TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), AnyTypeUnaryExpression(Literal.create(null, NullType)), AnyTypeUnaryExpression(Literal.create(null, NullType))) - ruleTest(TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), NumericTypeUnaryExpression(Literal.create(null, NullType)), NumericTypeUnaryExpression(Literal.create(null, DoubleType))) } @@ -536,17 +536,17 @@ class TypeCoercionSuite extends AnalysisTest { test("cast NullType for binary operators") { import TypeCoercionSuite._ - ruleTest(TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType))) - ruleTest(TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), NumericTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), NumericTypeBinaryOperator(Literal.create(null, DoubleType), Literal.create(null, DoubleType))) } test("coalesce casts") { - val rule = TypeCoercion.FunctionArgumentConversion(conf) + val rule = TypeCoercion.FunctionArgumentConversion val intLit = Literal(1) val longLit = Literal.create(1L) @@ -606,7 +606,7 @@ class TypeCoercionSuite extends AnalysisTest { } test("CreateArray casts") { - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateArray(Literal(1.0) :: Literal(1) :: Literal.create(1.0, FloatType) @@ -616,7 +616,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateArray(Literal(1.0) :: Literal(1) :: Literal("a") @@ -626,7 +626,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal("a"), StringType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateArray(Literal.create(null, DecimalType(5, 3)) :: Literal(1) :: Nil), @@ -634,7 +634,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal(1).cast(DecimalType(13, 3)) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateArray(Literal.create(null, DecimalType(5, 3)) :: Literal.create(null, DecimalType(22, 10)) :: Literal.create(null, DecimalType(38, 38)) @@ -647,7 +647,7 @@ class TypeCoercionSuite extends AnalysisTest { test("CreateMap casts") { // type coercion for map keys - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateMap(Literal(1) :: Literal("a") :: Literal.create(2.0, FloatType) @@ -658,7 +658,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal.create(2.0, FloatType), FloatType) :: Literal("b") :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateMap(Literal.create(null, DecimalType(5, 3)) :: Literal("a") :: Literal.create(2.0, FloatType) @@ -670,7 +670,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal("b") :: Nil)) // type coercion for map values - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateMap(Literal(1) :: Literal("a") :: Literal(2) @@ -681,7 +681,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal(2) :: Cast(Literal(3.0), StringType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateMap(Literal(1) :: Literal.create(null, DecimalType(38, 0)) :: Literal(2) @@ -693,7 +693,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal.create(null, DecimalType(38, 38)).cast(DecimalType(38, 38)) :: Nil)) // type coercion for both map keys and values - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateMap(Literal(1) :: Literal("a") :: Literal(2.0) @@ -708,7 +708,7 @@ class TypeCoercionSuite extends AnalysisTest { test("greatest/least cast") { for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) { - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, operator(Literal(1.0) :: Literal(1) :: Literal.create(1.0, FloatType) @@ -717,7 +717,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal(1), DoubleType) :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, operator(Literal(1L) :: Literal(1) :: Literal(new java.math.BigDecimal("1000000000000000000000")) @@ -726,7 +726,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal(1), DecimalType(22, 0)) :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0)) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, operator(Literal(1.0) :: Literal.create(null, DecimalType(10, 5)) :: Literal(1) @@ -735,7 +735,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal.create(null, DecimalType(10, 5)).cast(DoubleType) :: Literal(1).cast(DoubleType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, operator(Literal.create(null, DecimalType(15, 0)) :: Literal.create(null, DecimalType(10, 5)) :: Literal(1) @@ -744,7 +744,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal.create(null, DecimalType(10, 5)).cast(DecimalType(20, 5)) :: Literal(1).cast(DecimalType(20, 5)) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, operator(Literal.create(2L, LongType) :: Literal(1) :: Literal.create(null, DecimalType(10, 5)) @@ -757,25 +757,25 @@ class TypeCoercionSuite extends AnalysisTest { } test("nanvl casts") { - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, NaNvl(Literal.create(1.0f, FloatType), Literal.create(1.0, DoubleType)), NaNvl(Cast(Literal.create(1.0f, FloatType), DoubleType), Literal.create(1.0, DoubleType))) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0f, FloatType)), NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(1.0f, FloatType), DoubleType))) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)), NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType))) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, NaNvl(Literal.create(1.0f, FloatType), Literal.create(null, NullType)), NaNvl(Literal.create(1.0f, FloatType), Cast(Literal.create(null, NullType), FloatType))) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, NaNvl(Literal.create(1.0, DoubleType), Literal.create(null, NullType)), NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(null, NullType), DoubleType))) } test("type coercion for If") { - val rule = TypeCoercion.IfCoercion(conf) + val rule = TypeCoercion.IfCoercion val intLit = Literal(1) val doubleLit = Literal(1.0) val trueLit = Literal.create(true, BooleanType) @@ -823,20 +823,20 @@ class TypeCoercionSuite extends AnalysisTest { } test("type coercion for CaseKeyWhen") { - ruleTest(TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))), CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a"))) ) - ruleTest(TypeCoercion.CaseWhenCoercion(conf), + ruleTest(TypeCoercion.CaseWhenCoercion, CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))), CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))) ) - ruleTest(TypeCoercion.CaseWhenCoercion(conf), + ruleTest(TypeCoercion.CaseWhenCoercion, CaseWhen(Seq((Literal(true), Literal(1.2))), Literal.create(1, DecimalType(7, 2))), CaseWhen(Seq((Literal(true), Literal(1.2))), Cast(Literal.create(1, DecimalType(7, 2)), DoubleType)) ) - ruleTest(TypeCoercion.CaseWhenCoercion(conf), + ruleTest(TypeCoercion.CaseWhenCoercion, CaseWhen(Seq((Literal(true), Literal(100L))), Literal.create(1, DecimalType(7, 2))), CaseWhen(Seq((Literal(true), Cast(Literal(100L), DecimalType(22, 2)))), Cast(Literal.create(1, DecimalType(7, 2)), DecimalType(22, 2))) @@ -1085,7 +1085,7 @@ class TypeCoercionSuite extends AnalysisTest { private val timeZoneResolver = ResolveTimeZone(new SQLConf) private def widenSetOperationTypes(plan: LogicalPlan): LogicalPlan = { - timeZoneResolver(TypeCoercion.WidenSetOperationTypes(conf)(plan)) + timeZoneResolver(TypeCoercion.WidenSetOperationTypes(plan)) } test("WidenSetOperationTypes for except and intersect") { @@ -1256,7 +1256,7 @@ class TypeCoercionSuite extends AnalysisTest { test("SPARK-15776 Divide expression's dataType should be casted to Double or Decimal " + "in aggregation function like sum") { - val rules = Seq(FunctionArgumentConversion(conf), Division) + val rules = Seq(FunctionArgumentConversion, Division) // Casts Integer to Double ruleTest(rules, sum(Divide(4, 3)), sum(Divide(Cast(4, DoubleType), Cast(3, DoubleType)))) // Left expression is Double, right expression is Int. Another rule ImplicitTypeCasts will @@ -1275,7 +1275,7 @@ class TypeCoercionSuite extends AnalysisTest { } test("SPARK-17117 null type coercion in divide") { - val rules = Seq(FunctionArgumentConversion(conf), Division, ImplicitTypeCasts(conf)) + val rules = Seq(FunctionArgumentConversion, Division, new ImplicitTypeCasts(conf)) val nullLit = Literal.create(null, NullType) ruleTest(rules, Divide(1L, nullLit), Divide(Cast(1L, DoubleType), Cast(nullLit, DoubleType))) ruleTest(rules, Divide(nullLit, 1L), Divide(Cast(nullLit, DoubleType), Cast(1L, DoubleType))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index a2851d071c7c6..3fc0b08c56e02 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -56,6 +57,28 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MapValues(m2), null) } + test("MapEntries") { + def r(values: Any*): InternalRow = create_row(values: _*) + + // Primitive-type keys/values + val mi0 = Literal.create(Map(1 -> 1, 2 -> null, 3 -> 2), MapType(IntegerType, IntegerType)) + val mi1 = Literal.create(Map[Int, Int](), MapType(IntegerType, IntegerType)) + val mi2 = Literal.create(null, MapType(IntegerType, IntegerType)) + + checkEvaluation(MapEntries(mi0), Seq(r(1, 1), r(2, null), r(3, 2))) + checkEvaluation(MapEntries(mi1), Seq.empty) + checkEvaluation(MapEntries(mi2), null) + + // Non-primitive-type keys/values + val ms0 = Literal.create(Map("a" -> "c", "b" -> null), MapType(StringType, StringType)) + val ms1 = Literal.create(Map[Int, Int](), MapType(StringType, StringType)) + val ms2 = Literal.create(null, MapType(StringType, StringType)) + + checkEvaluation(MapEntries(ms0), Seq(r("a", "c"), r("b", null))) + checkEvaluation(MapEntries(ms1), Seq.empty) + checkEvaluation(MapEntries(ms2), null) + } + test("Sort Array") { val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType)) val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) @@ -134,6 +157,99 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(a3, Literal("")), null) checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) + + // binary + val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)), + ArrayType(BinaryType)) + val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](4, 3)), + ArrayType(BinaryType)) + val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null), + ArrayType(BinaryType)) + val b3 = Literal.create(Seq[Array[Byte]](null, Array[Byte](1, 2)), + ArrayType(BinaryType)) + val be = Literal.create(Array[Byte](1, 2), BinaryType) + val nullBinary = Literal.create(null, BinaryType) + + checkEvaluation(ArrayContains(b0, be), true) + checkEvaluation(ArrayContains(b1, be), false) + checkEvaluation(ArrayContains(b0, nullBinary), null) + checkEvaluation(ArrayContains(b2, be), null) + checkEvaluation(ArrayContains(b3, be), true) + + // complex data types + val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)), + ArrayType(ArrayType(IntegerType))) + val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)), + ArrayType(ArrayType(IntegerType))) + val aae = Literal.create(Seq[Int](1, 2), ArrayType(IntegerType)) + checkEvaluation(ArrayContains(aa0, aae), true) + checkEvaluation(ArrayContains(aa1, aae), false) + } + + test("ArraysOverlap") { + val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) + val a1 = Literal.create(Seq(4, 5, 3), ArrayType(IntegerType)) + val a2 = Literal.create(Seq(null, 5, 6), ArrayType(IntegerType)) + val a3 = Literal.create(Seq(7, 8), ArrayType(IntegerType)) + val a4 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) + val a5 = Literal.create(Seq[String]("", "abc"), ArrayType(StringType)) + val a6 = Literal.create(Seq[String]("def", "ghi"), ArrayType(StringType)) + + val emptyIntArray = Literal.create(Seq.empty[Int], ArrayType(IntegerType)) + + checkEvaluation(ArraysOverlap(a0, a1), true) + checkEvaluation(ArraysOverlap(a0, a2), null) + checkEvaluation(ArraysOverlap(a1, a2), true) + checkEvaluation(ArraysOverlap(a1, a3), false) + checkEvaluation(ArraysOverlap(a0, emptyIntArray), false) + checkEvaluation(ArraysOverlap(a2, emptyIntArray), false) + checkEvaluation(ArraysOverlap(emptyIntArray, a2), false) + + checkEvaluation(ArraysOverlap(a4, a5), true) + checkEvaluation(ArraysOverlap(a4, a6), null) + checkEvaluation(ArraysOverlap(a5, a6), false) + + // null handling + checkEvaluation(ArraysOverlap(emptyIntArray, a2), false) + checkEvaluation(ArraysOverlap( + emptyIntArray, Literal.create(Seq(null), ArrayType(IntegerType))), false) + checkEvaluation(ArraysOverlap(Literal.create(null, ArrayType(IntegerType)), a0), null) + checkEvaluation(ArraysOverlap(a0, Literal.create(null, ArrayType(IntegerType))), null) + checkEvaluation(ArraysOverlap( + Literal.create(Seq(null), ArrayType(IntegerType)), + Literal.create(Seq(null), ArrayType(IntegerType))), null) + + // arrays of binaries + val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4)), + ArrayType(BinaryType)) + val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)), + ArrayType(BinaryType)) + val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](4, 3)), + ArrayType(BinaryType)) + + checkEvaluation(ArraysOverlap(b0, b1), true) + checkEvaluation(ArraysOverlap(b0, b2), false) + + // arrays of complex data types + val aa0 = Literal.create(Seq[Array[String]](Array[String]("a", "b"), Array[String]("c", "d")), + ArrayType(ArrayType(StringType))) + val aa1 = Literal.create(Seq[Array[String]](Array[String]("e", "f"), Array[String]("a", "b")), + ArrayType(ArrayType(StringType))) + val aa2 = Literal.create(Seq[Array[String]](Array[String]("b", "a"), Array[String]("f", "g")), + ArrayType(ArrayType(StringType))) + + checkEvaluation(ArraysOverlap(aa0, aa1), true) + checkEvaluation(ArraysOverlap(aa0, aa2), false) + + // null handling with complex datatypes + val emptyBinaryArray = Literal.create(Seq.empty[Array[Byte]], ArrayType(BinaryType)) + val arrayWithBinaryNull = Literal.create(Seq(null), ArrayType(BinaryType)) + checkEvaluation(ArraysOverlap(emptyBinaryArray, b0), false) + checkEvaluation(ArraysOverlap(b0, emptyBinaryArray), false) + checkEvaluation(ArraysOverlap(emptyBinaryArray, arrayWithBinaryNull), false) + checkEvaluation(ArraysOverlap(arrayWithBinaryNull, emptyBinaryArray), false) + checkEvaluation(ArraysOverlap(arrayWithBinaryNull, b0), null) + checkEvaluation(ArraysOverlap(b0, arrayWithBinaryNull), null) } test("Slice") { @@ -283,6 +399,14 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayPosition(a3, Literal("")), null) checkEvaluation(ArrayPosition(a3, Literal.create(null, StringType)), null) + + val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)), + ArrayType(ArrayType(IntegerType))) + val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)), + ArrayType(ArrayType(IntegerType))) + val aae = Literal.create(Seq[Int](1, 2), ArrayType(IntegerType)) + checkEvaluation(ArrayPosition(aa0, aae), 1L) + checkEvaluation(ArrayPosition(aa1, aae), 0L) } test("elementAt") { @@ -320,7 +444,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType)) val m2 = Literal.create(null, MapType(StringType, StringType)) - checkEvaluation(ElementAt(m0, Literal(1.0)), null) + assert(ElementAt(m0, Literal(1.0)).checkInputDataTypes().isFailure) checkEvaluation(ElementAt(m0, Literal("d")), null) @@ -331,6 +455,18 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ElementAt(m0, Literal("c")), null) checkEvaluation(ElementAt(m2, Literal("a")), null) + + // test binary type as keys + val mb0 = Literal.create( + Map(Array[Byte](1, 2) -> "1", Array[Byte](3, 4) -> null, Array[Byte](2, 1) -> "2"), + MapType(BinaryType, StringType)) + val mb1 = Literal.create(Map[Array[Byte], String](), MapType(BinaryType, StringType)) + + checkEvaluation(ElementAt(mb0, Literal(Array[Byte](1, 2, 3))), null) + + checkEvaluation(ElementAt(mb1, Literal(Array[Byte](1, 2))), null) + checkEvaluation(ElementAt(mb0, Literal(Array[Byte](2, 1), BinaryType)), "2") + checkEvaluation(ElementAt(mb0, Literal(Array[Byte](3, 4))), null) } test("Concat") { @@ -468,4 +604,22 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Flatten(asa3), null) checkEvaluation(Flatten(asa4), null) } + + test("ArrayRepeat") { + val intArray = Literal.create(Seq(1, 2), ArrayType(IntegerType)) + val strArray = Literal.create(Seq("hi", "hola"), ArrayType(StringType)) + + checkEvaluation(ArrayRepeat(Literal("hi"), Literal(0)), Seq()) + checkEvaluation(ArrayRepeat(Literal("hi"), Literal(-1)), Seq()) + checkEvaluation(ArrayRepeat(Literal("hi"), Literal(1)), Seq("hi")) + checkEvaluation(ArrayRepeat(Literal("hi"), Literal(2)), Seq("hi", "hi")) + checkEvaluation(ArrayRepeat(Literal(true), Literal(2)), Seq(true, true)) + checkEvaluation(ArrayRepeat(Literal(1), Literal(2)), Seq(1, 1)) + checkEvaluation(ArrayRepeat(Literal(3.2), Literal(2)), Seq(3.2, 3.2)) + checkEvaluation(ArrayRepeat(Literal(null), Literal(2)), Seq[String](null, null)) + checkEvaluation(ArrayRepeat(Literal(null, IntegerType), Literal(2)), Seq[Integer](null, null)) + checkEvaluation(ArrayRepeat(intArray, Literal(2)), Seq(Seq(1, 2), Seq(1, 2))) + checkEvaluation(ArrayRepeat(strArray, Literal(2)), Seq(Seq("hi", "hola"), Seq("hi", "hola"))) + checkEvaluation(ArrayRepeat(Literal("hi"), Literal(null, IntegerType)), null) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index a22e9d4655e8c..c2a44e0d33b18 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -98,6 +98,9 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { if (expected.isNaN) result.isNaN else expected == result case (result: Float, expected: Float) => if (expected.isNaN) result.isNaN else expected == result + case (result: UnsafeRow, expected: GenericInternalRow) => + val structType = exprDataType.asInstanceOf[StructType] + result.toSeq(structType) == expected.toSeq(structType) case (result: Row, expected: InternalRow) => result.toSeq == expected.toSeq(result.schema) case _ => result == expected diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala index 64b65e2070ed6..7c7c4cccee253 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.{DataType, IntegerType} /** @@ -45,7 +46,7 @@ case class BadCodegenExpression() extends LeafExpression { override def eval(input: InternalRow): Any = 10 override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { ev.copy(code = - s""" + code""" |int some_variable = 11; |int ${ev.value} = 10; """.stripMargin) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala new file mode 100644 index 0000000000000..d2c6420eadb20 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.types.{BooleanType, IntegerType} + +class CodeBlockSuite extends SparkFunSuite { + + test("Block interpolates string and ExprValue inputs") { + val isNull = JavaCode.isNullVariable("expr1_isNull") + val stringLiteral = "false" + val code = code"boolean $isNull = $stringLiteral;" + assert(code.toString == "boolean expr1_isNull = false;") + } + + test("Literals are folded into string code parts instead of block inputs") { + val value = JavaCode.variable("expr1", IntegerType) + val intLiteral = 1 + val code = code"int $value = $intLiteral;" + assert(code.asInstanceOf[CodeBlock].blockInputs === Seq(value)) + } + + test("Block.stripMargin") { + val isNull = JavaCode.isNullVariable("expr1_isNull") + val value = JavaCode.variable("expr1", IntegerType) + val code1 = + code""" + |boolean $isNull = false; + |int $value = ${JavaCode.defaultLiteral(IntegerType)};""".stripMargin + val expected = + s""" + |boolean expr1_isNull = false; + |int expr1 = ${JavaCode.defaultLiteral(IntegerType)};""".stripMargin.trim + assert(code1.toString == expected) + + val code2 = + code""" + >boolean $isNull = false; + >int $value = ${JavaCode.defaultLiteral(IntegerType)};""".stripMargin('>') + assert(code2.toString == expected) + } + + test("Block can capture input expr values") { + val isNull = JavaCode.isNullVariable("expr1_isNull") + val value = JavaCode.variable("expr1", IntegerType) + val code = + code""" + |boolean $isNull = false; + |int $value = -1; + """.stripMargin + val exprValues = code.exprValues + assert(exprValues.size == 2) + assert(exprValues === Set(value, isNull)) + } + + test("concatenate blocks") { + val isNull1 = JavaCode.isNullVariable("expr1_isNull") + val value1 = JavaCode.variable("expr1", IntegerType) + val isNull2 = JavaCode.isNullVariable("expr2_isNull") + val value2 = JavaCode.variable("expr2", IntegerType) + val literal = JavaCode.literal("100", IntegerType) + + val code = + code""" + |boolean $isNull1 = false; + |int $value1 = -1;""".stripMargin + + code""" + |boolean $isNull2 = true; + |int $value2 = $literal;""".stripMargin + + val expected = + """ + |boolean expr1_isNull = false; + |int expr1 = -1; + |boolean expr2_isNull = true; + |int expr2 = 100;""".stripMargin.trim + + assert(code.toString == expected) + + val exprValues = code.exprValues + assert(exprValues.size == 5) + assert(exprValues === Set(isNull1, value1, isNull2, value2, literal)) + } + + test("Throws exception when interpolating unexcepted object in code block") { + val obj = Tuple2(1, 1) + val e = intercept[IllegalArgumentException] { + code"$obj" + } + assert(e.getMessage().contains(s"Can not interpolate ${obj.getClass.getName}")) + } + + test("replace expr values in code block") { + val expr = JavaCode.expression("1 + 1", IntegerType) + val isNull = JavaCode.isNullVariable("expr1_isNull") + val exprInFunc = JavaCode.variable("expr1", IntegerType) + + val code = + code""" + |callFunc(int $expr) { + | boolean $isNull = false; + | int $exprInFunc = $expr + 1; + |}""".stripMargin + + val aliasedParam = JavaCode.variable("aliased", expr.javaType) + val aliasedInputs = code.asInstanceOf[CodeBlock].blockInputs.map { + case _: SimpleExprValue => aliasedParam + case other => other + } + val aliasedCode = CodeBlock(code.asInstanceOf[CodeBlock].codeParts, aliasedInputs).stripMargin + val expected = + code""" + |callFunc(int $aliasedParam) { + | boolean $isNull = false; + | int $exprInFunc = $aliasedParam + 1; + |}""".stripMargin + assert(aliasedCode.toString == expected.toString) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index 633d86d495581..5452e72b38647 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -439,4 +439,17 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { .select('c as 'sCol2, 'a as 'sCol1) checkRule(originalQuery, correctAnswer) } + + test("SPARK-24313: support binary type as map keys in GetMapValue") { + val mb0 = Literal.create( + Map(Array[Byte](1, 2) -> "1", Array[Byte](3, 4) -> null, Array[Byte](2, 1) -> "2"), + MapType(BinaryType, StringType)) + val mb1 = Literal.create(Map[Array[Byte], String](), MapType(BinaryType, StringType)) + + checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](1, 2, 3))), null) + + checkEvaluation(GetMapValue(mb1, Literal(Array[Byte](1, 2))), null) + checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](2, 1), BinaryType)), "2") + checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](3, 4))), null) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 812bfdd7bb885..fb51376c6163f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -318,6 +318,16 @@ class PlanParserSuite extends AnalysisTest { assertEqual( "select * from t lateral view posexplode(x) posexpl as x, y", expected) + + intercept( + """select * + |from t + |lateral view explode(x) expl + |pivot ( + | sum(x) + | FOR y IN ('a', 'b') + |)""".stripMargin, + "LATERAL cannot be used together with PIVOT in FROM clause") } test("joins") { diff --git a/sql/core/pom.xml b/sql/core/pom.xml index ef41837f89d68..f270c70fbfcf0 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -38,7 +38,7 @@ com.univocity univocity-parsers - 2.5.9 + 2.6.3 jar diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index 10d6ed85a4080..daedfd7e78f5f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.datasources.parquet; -import java.io.ByteArrayInputStream; import java.io.File; import java.io.IOException; import java.lang.reflect.InvocationTargetException; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java index aacefacfc1c1a..c62dc3d86386e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java @@ -26,7 +26,6 @@ import org.apache.parquet.column.values.ValuesReader; import org.apache.parquet.io.api.Binary; -import org.apache.spark.unsafe.Platform; /** * An implementation of the Parquet PLAIN decoder that supports the vectorized interface. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java index 0ea4dc6b5def3..b2526ded53d92 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java @@ -30,7 +30,7 @@ public interface ReadSupport extends DataSourceV2 { /** * Creates a {@link DataSourceReader} to scan the data from this data source. * - * If this method fails (by throwing an exception), the action would fail and no Spark job was + * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. * * @param options the options for the returned data source reader, which is an immutable diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java index 3801402268af1..f31659904cc53 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java @@ -35,7 +35,7 @@ public interface ReadSupportWithSchema extends DataSourceV2 { /** * Create a {@link DataSourceReader} to scan the data from this data source. * - * If this method fails (by throwing an exception), the action would fail and no Spark job was + * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. * * @param schema the full schema of this data source reader. Full schema usually maps to the diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java index cab56453816cc..83aeec0c47853 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java @@ -35,7 +35,7 @@ public interface WriteSupport extends DataSourceV2 { * Creates an optional {@link DataSourceWriter} to save the data to this data source. Data * sources can return None if there is no writing needed to be done according to the save mode. * - * If this method fails (by throwing an exception), the action would fail and no Spark job was + * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. * * @param jobId A unique string for the writing job. It's possible that there are many writing diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java index c24f3b21eade1..dcb87715d0b6f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ContinuousInputPartition.java @@ -27,9 +27,9 @@ @InterfaceStability.Evolving public interface ContinuousInputPartition extends InputPartition { /** - * Create a DataReader with particular offset as its startOffset. + * Create an input partition reader with particular offset as its startOffset. * - * @param offset offset want to set as the DataReader's startOffset. + * @param offset offset want to set as the input partition reader's startOffset. */ InputPartitionReader createContinuousReader(PartitionOffset offset); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java index f898c296e4245..36a3e542b5a11 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceReader.java @@ -31,7 +31,7 @@ * {@link ReadSupport#createReader(DataSourceOptions)} or * {@link ReadSupportWithSchema#createReader(StructType, DataSourceOptions)}. * It can mix in various query optimization interfaces to speed up the data scan. The actual scan - * logic is delegated to {@link InputPartition}s that are returned by + * logic is delegated to {@link InputPartition}s, which are returned by * {@link #planInputPartitions()}. * * There are mainly 3 kinds of query optimizations: @@ -45,8 +45,8 @@ * only one of them would be respected, according to the priority list from high to low: * {@link SupportsScanColumnarBatch}, {@link SupportsScanUnsafeRow}. * - * If an exception was throw when applying any of these query optimizations, the action would fail - * and no Spark job was submitted. + * If an exception was throw when applying any of these query optimizations, the action will fail + * and no Spark job will be submitted. * * Spark first applies all operator push-down optimizations that this data source supports. Then * Spark collects information this data source reported for further optimizations. Finally Spark @@ -59,21 +59,21 @@ public interface DataSourceReader { * Returns the actual schema of this data source reader, which may be different from the physical * schema of the underlying storage, as column pruning or other optimizations may happen. * - * If this method fails (by throwing an exception), the action would fail and no Spark job was + * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. */ StructType readSchema(); /** - * Returns a list of read tasks. Each task is responsible for creating a data reader to - * output data for one RDD partition. That means the number of tasks returned here is same as - * the number of RDD partitions this scan outputs. + * Returns a list of {@link InputPartition}s. Each {@link InputPartition} is responsible for + * creating a data reader to output data of one RDD partition. The number of input partitions + * returned here is the same as the number of RDD partitions this scan outputs. * * Note that, this may not be a full scan if the data source reader mixes in other optimization * interfaces like column pruning, filter push-down, etc. These optimizations are applied before * Spark issues the scan request. * - * If this method fails (by throwing an exception), the action would fail and no Spark job was + * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. */ List> planInputPartitions(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java index c581e3b5d0047..f2038d0de3ffe 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java @@ -23,28 +23,29 @@ /** * An input partition returned by {@link DataSourceReader#planInputPartitions()} and is - * responsible for creating the actual data reader. The relationship between - * {@link InputPartition} and {@link InputPartitionReader} + * responsible for creating the actual data reader of one RDD partition. + * The relationship between {@link InputPartition} and {@link InputPartitionReader} * is similar to the relationship between {@link Iterable} and {@link java.util.Iterator}. * - * Note that input partitions will be serialized and sent to executors, then the partition reader - * will be created on executors and do the actual reading. So {@link InputPartition} must be - * serializable and {@link InputPartitionReader} doesn't need to be. + * Note that {@link InputPartition}s will be serialized and sent to executors, then + * {@link InputPartitionReader}s will be created on executors to do the actual reading. So + * {@link InputPartition} must be serializable while {@link InputPartitionReader} doesn't need to + * be. */ @InterfaceStability.Evolving public interface InputPartition extends Serializable { /** - * The preferred locations where the data reader returned by this partition can run faster, - * but Spark does not guarantee to run the data reader on these locations. + * The preferred locations where the input partition reader returned by this partition can run + * faster, but Spark does not guarantee to run the input partition reader on these locations. * The implementations should make sure that it can be run on any location. * The location is a string representing the host name. * * Note that if a host name cannot be recognized by Spark, it will be ignored as it was not in - * the returned locations. By default this method returns empty string array, which means this - * task has no location preference. + * the returned locations. The default return value is empty string array, which means this + * input partition's reader has no location preference. * - * If this method fails (by throwing an exception), the action would fail and no Spark job was + * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. */ default String[] preferredLocations() { @@ -52,7 +53,7 @@ default String[] preferredLocations() { } /** - * Returns a data reader to do the actual reading work. + * Returns an input partition reader to do the actual reading work. * * If this method fails (by throwing an exception), the corresponding Spark task would fail and * get retried until hitting the maximum retry times. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java index 1b7051f1ad0af..33fa7be4c1b20 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartitionReader.java @@ -23,12 +23,12 @@ import org.apache.spark.annotation.InterfaceStability; /** - * A data reader returned by {@link InputPartition#createPartitionReader()} and is responsible for - * outputting data for a RDD partition. + * An input partition reader returned by {@link InputPartition#createPartitionReader()} and is + * responsible for outputting data for a RDD partition. * - * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data - * source readers, or {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow} for data source - * readers that mix in {@link SupportsScanUnsafeRow}. + * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal input + * partition readers, or {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow} for input + * partition readers that mix in {@link SupportsScanUnsafeRow}. */ @InterfaceStability.Evolving public interface InputPartitionReader extends Closeable { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java index d2ee9518d628f..5e32ba6952e1c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java @@ -22,7 +22,8 @@ /** * An interface to represent data distribution requirement, which specifies how the records should - * be distributed among the data partitions(one {@link InputPartitionReader} outputs data for one partition). + * be distributed among the data partitions (one {@link InputPartitionReader} outputs data for one + * partition). * Note that this interface has nothing to do with the data ordering inside one * partition(the output records of a single {@link InputPartitionReader}). * diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java index 716c5c0e9e15a..6e960bedf8020 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousReader.java @@ -35,8 +35,8 @@ @InterfaceStability.Evolving public interface ContinuousReader extends BaseStreamingSource, DataSourceReader { /** - * Merge partitioned offsets coming from {@link ContinuousInputPartitionReader} instances for each - * partition to a single global offset. + * Merge partitioned offsets coming from {@link ContinuousInputPartitionReader} instances + * for each partition to a single global offset. */ Offset mergeOffsets(PartitionOffset[] offsets); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java index 0a0fd8db58035..0030a9f05dba7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java @@ -34,8 +34,8 @@ * It can mix in various writing optimization interfaces to speed up the data saving. The actual * writing logic is delegated to {@link DataWriter}. * - * If an exception was throw when applying any of these writing optimizations, the action would fail - * and no Spark job was submitted. + * If an exception was throw when applying any of these writing optimizations, the action will fail + * and no Spark job will be submitted. * * The writing procedure is: * 1. Create a writer factory by {@link #createWriterFactory()}, serialize and send it to all the @@ -58,7 +58,7 @@ public interface DataSourceWriter { /** * Creates a writer factory which will be serialized and sent to executors. * - * If this method fails (by throwing an exception), the action would fail and no Spark job was + * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. */ DataWriterFactory createWriterFactory(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java index c2c2ab73257e8..7527bcc0c4027 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java @@ -35,7 +35,7 @@ public interface DataWriterFactory extends Serializable { /** * Returns a data writer to do the actual writing work. * - * If this method fails (by throwing an exception), the action would fail and no Spark job was + * If this method fails (by throwing an exception), the action will fail and no Spark job will be * submitted. * * @param partitionId A unique id of the RDD partition that the returned writer will process. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 53f44888ebaff..917f0cb221412 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -257,7 +257,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * @param connectionProperties JDBC database connection arguments, a list of arbitrary string * tag/value. Normally at least a "user" and "password" property * should be included. "fetchsize" can be used to control the - * number of rows per fetch. + * number of rows per fetch and "queryTimeout" can be used to wait + * for a Statement object to execute to the given number of seconds. * @since 1.4.0 */ def jdbc( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index d518e07bfb62c..32267eb0300f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -511,6 +511,16 @@ class Dataset[T] private[sql]( */ def isLocal: Boolean = logicalPlan.isInstanceOf[LocalRelation] + /** + * Returns true if the `Dataset` is empty. + * + * @group basic + * @since 2.4.0 + */ + def isEmpty: Boolean = withAction("isEmpty", limit(1).groupBy().count().queryExecution) { plan => + plan.executeCollect().head.getLong(0) == 0 + } + /** * Returns true if this Dataset contains one or more sources that continuously * return data as it arrives. A Dataset that reads data from a streaming source @@ -1607,7 +1617,9 @@ class Dataset[T] private[sql]( */ @Experimental @InterfaceStability.Evolving - def reduce(func: (T, T) => T): T = rdd.reduce(func) + def reduce(func: (T, T) => T): T = withNewRDDExecutionId { + rdd.reduce(func) + } /** * :: Experimental :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index fc3dbc1c5591b..48abad9078650 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.DataType import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} @@ -58,14 +59,14 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport { } val valueVar = ctx.freshName("value") val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]" - val code = s"${ctx.registerComment(str)}\n" + (if (nullable) { - s""" + val code = code"${ctx.registerComment(str)}" + (if (nullable) { + code""" boolean $isNullVar = $columnVar.isNullAt($ordinal); $javaType $valueVar = $isNullVar ? ${CodeGenerator.defaultValue(dataType)} : ($value); """ } else { - s"$javaType $valueVar = $value;" - }).trim + code"$javaType $valueVar = $value;" + }) ExprCode(code, isNullVar, JavaCode.variable(valueVar, dataType)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 08ff33afbba3d..61c14fee09337 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -69,7 +69,7 @@ trait DataSourceScanExec extends LeafExecNode with CodegenSupport { * Shorthand for calling redactString() without specifying redacting rules */ private def redact(text: String): String = { - Utils.redact(sqlContext.sessionState.conf.stringRedationPattern, text) + Utils.redact(sqlContext.sessionState.conf.stringRedactionPattern, text) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index e4812f3d338fb..5b4edf5136e3f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -152,7 +153,7 @@ case class ExpandExec( } else { val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") - val code = s""" + val code = code""" |boolean $isNull = true; |${CodeGenerator.javaType(firstExpr.dataType)} $value = | ${CodeGenerator.defaultValue(firstExpr.dataType)}; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index f40c50df74ccb..2549b9e1537a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -21,6 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types._ @@ -313,13 +314,13 @@ case class GenerateExec( if (checks.nonEmpty) { val isNull = ctx.freshName("isNull") val code = - s""" + code""" |boolean $isNull = ${checks.mkString(" || ")}; |$javaType $value = $isNull ? ${CodeGenerator.defaultValue(dt)} : $getter; """.stripMargin ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, dt)) } else { - ExprCode(s"$javaType $value = $getter;", FalseLiteral, JavaCode.variable(value, dt)) + ExprCode(code"$javaType $value = $getter;", FalseLiteral, JavaCode.variable(value, dt)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 15379a0663f7d..3112b306c365e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -225,7 +225,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { * Redact the sensitive information in the given string. */ private def withRedaction(message: String): String = { - Utils.redact(sparkSession.sessionState.conf.stringRedationPattern, message) + Utils.redact(sparkSession.sessionState.conf.stringRedactionPattern, message) } /** A special namespace for commands that can be used to debug query execution. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 2c5102b1e5ee7..439932b0cc3ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -68,16 +68,18 @@ object SQLExecution { // sparkContext.getCallSite() would first try to pick up any call site that was previously // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on // streaming queries would give us call site like "run at :0" - val callSite = sparkSession.sparkContext.getCallSite() + val callSite = sc.getCallSite() - sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart( - executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, - SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) - try { - body - } finally { - sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd( - executionId, System.currentTimeMillis())) + withSQLConfPropagated(sparkSession) { + sc.listenerBus.post(SparkListenerSQLExecutionStart( + executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, + SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) + try { + body + } finally { + sc.listenerBus.post(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) + } } } finally { executionIdToQueryExecution.remove(executionId) @@ -90,13 +92,41 @@ object SQLExecution { * thread from the original one, this method can be used to connect the Spark jobs in this action * with the known executionId, e.g., `BroadcastExchangeExec.relationFuture`. */ - def withExecutionId[T](sc: SparkContext, executionId: String)(body: => T): T = { + def withExecutionId[T](sparkSession: SparkSession, executionId: String)(body: => T): T = { + val sc = sparkSession.sparkContext val oldExecutionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + withSQLConfPropagated(sparkSession) { + try { + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId) + body + } finally { + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId) + } + } + } + + /** + * Wrap an action with specified SQL configs. These configs will be propagated to the executor + * side via job local properties. + */ + def withSQLConfPropagated[T](sparkSession: SparkSession)(body: => T): T = { + val sc = sparkSession.sparkContext + // Set all the specified SQL configs to local properties, so that they can be available at + // the executor side. + val allConfigs = sparkSession.sessionState.conf.getAllConfs + val originalLocalProps = allConfigs.collect { + case (key, value) if key.startsWith("spark") => + val originalValue = sc.getLocalProperty(key) + sc.setLocalProperty(key, value) + (key, originalValue) + } + try { - sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId) body } finally { - sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId) + for ((key, value) <- originalLocalProps) { + sc.setLocalProperty(key, value) + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 82b4eb9fba242..b97a87a122406 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -66,9 +66,11 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object SpecialLimits extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case ReturnAnswer(rootPlan) => rootPlan match { - case Limit(IntegerLiteral(limit), Sort(order, true, child)) => + case Limit(IntegerLiteral(limit), Sort(order, true, child)) + if limit < conf.topKSortFallbackThreshold => TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil - case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child))) => + case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child))) + if limit < conf.topKSortFallbackThreshold => TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil case Limit(IntegerLiteral(limit), child) => // With whole stage codegen, Spark releases resources only when all the output data of the @@ -79,9 +81,11 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { CollectLimitExec(limit, LocalLimitExec(limit, planLater(child))) :: Nil case other => planLater(other) :: Nil } - case Limit(IntegerLiteral(limit), Sort(order, true, child)) => + case Limit(IntegerLiteral(limit), Sort(order, true, child)) + if limit < conf.topKSortFallbackThreshold => TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil - case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child))) => + case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child))) + if limit < conf.topKSortFallbackThreshold => TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil case _ => Nil } @@ -361,7 +365,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case Join(left, right, _, _) if left.isStreaming && right.isStreaming => throw new AnalysisException( - "Stream stream joins without equality predicate is not supported", plan = Some(plan)) + "Stream-stream join without equality predicate is not supported", plan = Some(plan)) case _ => Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 828b51fa199de..372dc3db36ce6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -27,6 +27,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.aggregate.HashAggregateExec @@ -122,10 +123,10 @@ trait CodegenSupport extends SparkPlan { ctx.INPUT_ROW = row ctx.currentVars = colVars val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false) - val code = s""" + val code = code""" |$evaluateInputs - |${ev.code.trim} - """.stripMargin.trim + |${ev.code} + """.stripMargin ExprCode(code, FalseLiteral, ev.value) } else { // There are no columns @@ -259,8 +260,8 @@ trait CodegenSupport extends SparkPlan { * them to be evaluated twice. */ protected def evaluateVariables(variables: Seq[ExprCode]): String = { - val evaluate = variables.filter(_.code != "").map(_.code.trim).mkString("\n") - variables.foreach(_.code = "") + val evaluate = variables.filter(_.code.nonEmpty).map(_.code.toString).mkString("\n") + variables.foreach(_.code = EmptyBlock) evaluate } @@ -275,8 +276,8 @@ trait CodegenSupport extends SparkPlan { val evaluateVars = new StringBuilder variables.zipWithIndex.foreach { case (ev, i) => if (ev.code != "" && required.contains(attributes(i))) { - evaluateVars.append(ev.code.trim + "\n") - ev.code = "" + evaluateVars.append(ev.code.toString + "\n") + ev.code = EmptyBlock } } evaluateVars.toString() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 6a8ec4f722aea..8c7b2c187cccd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} @@ -190,7 +191,7 @@ case class HashAggregateExec( val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue") // The initial expression should not access any column val ev = e.genCode(ctx) - val initVars = s""" + val initVars = code""" | $isNull = ${ev.isNull}; | $value = ${ev.value}; """.stripMargin @@ -773,8 +774,8 @@ case class HashAggregateExec( val findOrInsertRegularHashMap: String = s""" |// generate grouping key - |${unsafeRowKeyCode.code.trim} - |${hashEval.code.trim} + |${unsafeRowKeyCode.code} + |${hashEval.code} |if ($checkFallbackForBytesToBytesMap) { | // try to get the buffer from hash map | $unsafeRowBuffer = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index de2d630de3fdb..e1c85823259b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, DeclarativeAggregate} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ /** @@ -50,7 +51,7 @@ abstract class HashMapGenerator( val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue") val ev = e.genCode(ctx) val initVars = - s""" + code""" | $isNull = ${ev.isNull}; | $value = ${ev.value}; """.stripMargin diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 22b63513548fe..66888fce7f9f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -133,6 +133,14 @@ private[arrow] abstract class ArrowFieldWriter { valueVector match { case fixedWidthVector: BaseFixedWidthVector => fixedWidthVector.reset() case variableWidthVector: BaseVariableWidthVector => variableWidthVector.reset() + case listVector: ListVector => + // Manual "reset" the underlying buffer. + // TODO: When we upgrade to Arrow 0.10.0, we can simply remove this and call + // `listVector.reset()`. + val buffers = listVector.getBuffers(false) + buffers.foreach(buf => buf.setZero(0, buf.capacity())) + listVector.setValueCount(0) + listVector.setLastSet(0) case _ => } count = 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 1edfdc888afd8..9434ceb7cd16c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -345,6 +345,20 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) override val output: Seq[Attribute] = range.output + override def outputOrdering: Seq[SortOrder] = range.outputOrdering + + override def outputPartitioning: Partitioning = { + if (numElements > 0) { + if (numSlices == 1) { + SinglePartition + } else { + RangePartitioning(outputOrdering, numSlices) + } + } else { + UnknownPartitioning(0) + } + } + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) @@ -629,7 +643,7 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode { Future { // This will run in another thread. Set the execution id so that we can connect these jobs // with the correct execution. - SQLExecution.withExecutionId(sparkContext, executionId) { + SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) { val beforeCollect = System.nanoTime() // Note that we use .executeCollect() because we don't want to convert data to Scala types val rows: Array[InternalRow] = child.executeCollect() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 1edf27619ad7b..f9a24806953e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -32,7 +32,6 @@ import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCoercion} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.util.SchemaUtils @@ -522,8 +521,6 @@ object PartitioningUtils { private val findWiderTypeForPartitionColumn: (DataType, DataType) => DataType = { case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) => StringType case (DoubleType, LongType) | (LongType, DoubleType) => StringType - case (t1, t2) => - TypeCoercion.findWiderTypeForTwo( - t1, t2, SQLConf.get.caseSensitiveAnalysis).getOrElse(StringType) + case (t1, t2) => TypeCoercion.findWiderTypeForTwo(t1, t2).getOrElse(StringType) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala index 568e953a5db66..00b1b5dedb593 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala @@ -17,13 +17,12 @@ package org.apache.spark.sql.execution.datasources -import org.apache.spark.SparkEnv import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.command.RunnableCommand +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.CreatableRelationProvider -import org.apache.spark.util.Utils /** * Saves the results of `query` in to a data source. @@ -50,7 +49,7 @@ case class SaveIntoDataSourceCommand( } override def simpleString: String = { - val redacted = Utils.redact(SparkEnv.get.conf, options.toSeq).toMap + val redacted = SQLConf.get.redactOptions(options) s"SaveIntoDataSourceCommand ${dataSource}, ${redacted}, ${mode}" } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index ed2dc65a47914..dd41aee0f2ebc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -25,6 +25,7 @@ import org.apache.commons.lang3.time.FastDateFormat import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.internal.SQLConf class CSVOptions( @transient val parameters: CaseInsensitiveMap[String], @@ -80,6 +81,8 @@ class CSVOptions( } } + private[csv] val columnPruning = SQLConf.get.getConf(SQLConf.CSV_PARSER_COLUMN_PRUNING) + val delimiter = CSVUtils.toChar( parameters.getOrElse("sep", parameters.getOrElse("delimiter", ","))) val parseMode: ParseMode = @@ -164,7 +167,7 @@ class CSVOptions( writerSettings.setIgnoreLeadingWhitespaces(ignoreLeadingWhiteSpaceFlagInWrite) writerSettings.setIgnoreTrailingWhitespaces(ignoreTrailingWhiteSpaceFlagInWrite) writerSettings.setNullValue(nullValue) - writerSettings.setEmptyValue(nullValue) + writerSettings.setEmptyValue("\"\"") writerSettings.setSkipEmptyLines(true) writerSettings.setQuoteAllFields(quoteAll) writerSettings.setQuoteEscapingEnabled(escapeQuotes) @@ -185,6 +188,7 @@ class CSVOptions( settings.setInputBufferSize(inputBufferSize) settings.setMaxColumns(maxColumns) settings.setNullValue(nullValue) + settings.setEmptyValue("") settings.setMaxCharsPerColumn(maxCharsPerColumn) settings.setUnescapedQuoteHandling(UnescapedQuoteHandling.STOP_AT_DELIMITER) settings diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index 99557a1ceb0c8..4f00cc5eb3f39 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -34,10 +34,10 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String class UnivocityParser( - schema: StructType, + dataSchema: StructType, requiredSchema: StructType, val options: CSVOptions) extends Logging { - require(requiredSchema.toSet.subsetOf(schema.toSet), + require(requiredSchema.toSet.subsetOf(dataSchema.toSet), "requiredSchema should be the subset of schema.") def this(schema: StructType, options: CSVOptions) = this(schema, schema, options) @@ -45,9 +45,17 @@ class UnivocityParser( // A `ValueConverter` is responsible for converting the given value to a desired type. private type ValueConverter = String => Any - private val tokenizer = new CsvParser(options.asParserSettings) + private val tokenizer = { + val parserSetting = options.asParserSettings + if (options.columnPruning && requiredSchema.length < dataSchema.length) { + val tokenIndexArr = requiredSchema.map(f => java.lang.Integer.valueOf(dataSchema.indexOf(f))) + parserSetting.selectIndexes(tokenIndexArr: _*) + } + new CsvParser(parserSetting) + } + private val schema = if (options.columnPruning) requiredSchema else dataSchema - private val row = new GenericInternalRow(requiredSchema.length) + private val row = new GenericInternalRow(schema.length) // Retrieve the raw record string. private def getCurrentInput: UTF8String = { @@ -73,11 +81,8 @@ class UnivocityParser( // Each input token is placed in each output row's position by mapping these. In this case, // // output row - ["A", 2] - private val valueConverters: Array[ValueConverter] = + private val valueConverters: Array[ValueConverter] = { schema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray - - private val tokenIndexArr: Array[Int] = { - requiredSchema.map(f => schema.indexOf(f)).toArray } /** @@ -210,9 +215,8 @@ class UnivocityParser( } else { try { var i = 0 - while (i < requiredSchema.length) { - val from = tokenIndexArr(i) - row(i) = valueConverters(from).apply(tokens(from)) + while (i < schema.length) { + row(i) = valueConverters(i).apply(tokens(i)) i += 1 } row diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index b4e5d169066d9..a73a97c06fe5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -89,6 +89,10 @@ class JDBCOptions( // the number of partitions val numPartitions = parameters.get(JDBC_NUM_PARTITIONS).map(_.toInt) + // the number of seconds the driver will wait for a Statement object to execute to the given + // number of seconds. Zero means there is no limit. + val queryTimeout = parameters.getOrElse(JDBC_QUERY_TIMEOUT, "0").toInt + // ------------------------------------------------------------ // Optional parameters only for reading // ------------------------------------------------------------ @@ -160,6 +164,7 @@ object JDBCOptions { val JDBC_LOWER_BOUND = newOption("lowerBound") val JDBC_UPPER_BOUND = newOption("upperBound") val JDBC_NUM_PARTITIONS = newOption("numPartitions") + val JDBC_QUERY_TIMEOUT = newOption("queryTimeout") val JDBC_BATCH_FETCH_SIZE = newOption("fetchsize") val JDBC_TRUNCATE = newOption("truncate") val JDBC_CREATE_TABLE_OPTIONS = newOption("createTableOptions") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 05326210f3242..0bab3689e5d0e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -57,6 +57,7 @@ object JDBCRDD extends Logging { try { val statement = conn.prepareStatement(dialect.getSchemaQuery(table)) try { + statement.setQueryTimeout(options.queryTimeout) val rs = statement.executeQuery() try { JdbcUtils.getSchema(rs, dialect, alwaysNullable = true) @@ -281,6 +282,7 @@ private[jdbc] class JDBCRDD( val statement = conn.prepareStatement(sql) logInfo(s"Executing sessionInitStatement: $sql") try { + statement.setQueryTimeout(options.queryTimeout) statement.execute() } finally { statement.close() @@ -298,6 +300,7 @@ private[jdbc] class JDBCRDD( stmt = conn.prepareStatement(sqlText, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) stmt.setFetchSize(options.fetchSize) + stmt.setQueryTimeout(options.queryTimeout) rs = stmt.executeQuery() val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala index cc506e51bd0c6..f8c5677ea0f2a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala @@ -73,7 +73,7 @@ class JdbcRelationProvider extends CreatableRelationProvider saveTable(df, tableSchema, isCaseSensitive, options) } else { // Otherwise, do not truncate the table, instead drop and recreate it - dropTable(conn, options.table) + dropTable(conn, options.table, options) createTable(conn, df, options) saveTable(df, Some(df.schema), isCaseSensitive, options) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index e6dc2fda4eb1b..433443007cfd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -76,6 +76,7 @@ object JdbcUtils extends Logging { Try { val statement = conn.prepareStatement(dialect.getTableExistsQuery(options.table)) try { + statement.setQueryTimeout(options.queryTimeout) statement.executeQuery() } finally { statement.close() @@ -86,9 +87,10 @@ object JdbcUtils extends Logging { /** * Drops a table from the JDBC database. */ - def dropTable(conn: Connection, table: String): Unit = { + def dropTable(conn: Connection, table: String, options: JDBCOptions): Unit = { val statement = conn.createStatement try { + statement.setQueryTimeout(options.queryTimeout) statement.executeUpdate(s"DROP TABLE $table") } finally { statement.close() @@ -102,6 +104,7 @@ object JdbcUtils extends Logging { val dialect = JdbcDialects.get(options.url) val statement = conn.createStatement try { + statement.setQueryTimeout(options.queryTimeout) statement.executeUpdate(dialect.getTruncateQuery(options.table)) } finally { statement.close() @@ -254,6 +257,7 @@ object JdbcUtils extends Logging { try { val statement = conn.prepareStatement(dialect.getSchemaQuery(options.table)) try { + statement.setQueryTimeout(options.queryTimeout) Some(getSchema(statement.executeQuery(), dialect)) } catch { case _: SQLException => None @@ -596,7 +600,8 @@ object JdbcUtils extends Logging { insertStmt: String, batchSize: Int, dialect: JdbcDialect, - isolationLevel: Int): Iterator[Byte] = { + isolationLevel: Int, + options: JDBCOptions): Iterator[Byte] = { val conn = getConnection() var committed = false @@ -637,6 +642,9 @@ object JdbcUtils extends Logging { try { var rowCount = 0 + + stmt.setQueryTimeout(options.queryTimeout) + while (iterator.hasNext) { val row = iterator.next() var i = 0 @@ -819,7 +827,8 @@ object JdbcUtils extends Logging { case _ => df } repartitionedDF.rdd.foreachPartition(iterator => savePartition( - getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel) + getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel, + options) ) } @@ -841,6 +850,7 @@ object JdbcUtils extends Logging { val sql = s"CREATE TABLE $table ($strSchema) $createTableOptions" val statement = conn.createStatement try { + statement.setQueryTimeout(options.queryTimeout) statement.executeUpdate(sql) } finally { statement.close() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index ba83df0efebd0..3b6df45e949e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -34,6 +34,7 @@ import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.types.StructType @@ -104,22 +105,19 @@ object TextInputJsonDataSource extends JsonDataSource { CreateJacksonParser.internalRow(enc, _: JsonFactory, _: InternalRow) }.getOrElse(CreateJacksonParser.internalRow(_: JsonFactory, _: InternalRow)) - JsonInferSchema.infer(rdd, parsedOptions, rowParser) + SQLExecution.withSQLConfPropagated(json.sparkSession) { + JsonInferSchema.infer(rdd, parsedOptions, rowParser) + } } private def createBaseDataset( sparkSession: SparkSession, inputPaths: Seq[FileStatus], parsedOptions: JSONOptions): Dataset[String] = { - val paths = inputPaths.map(_.getPath.toString) - val textOptions = Map.empty[String, String] ++ - parsedOptions.encoding.map("encoding" -> _) ++ - parsedOptions.lineSeparator.map("lineSep" -> _) - sparkSession.baseRelationToDataFrame( DataSource.apply( sparkSession, - paths = paths, + paths = inputPaths.map(_.getPath.toString), className = classOf[TextFileFormat].getName, options = parsedOptions.parameters ).resolveRelation(checkFilesExist = false)) @@ -165,7 +163,9 @@ object MultiLineJsonDataSource extends JsonDataSource { .map(enc => createParser(enc, _: JsonFactory, _: PortableDataStream)) .getOrElse(createParser(_: JsonFactory, _: PortableDataStream)) - JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser) + SQLExecution.withSQLConfPropagated(sparkSession) { + JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser) + } } private def createBaseRdd( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala index e0424b7478122..e7eed95a560a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil import org.apache.spark.sql.catalyst.json.JSONOptions import org.apache.spark.sql.catalyst.util.{DropMalformedMode, FailFastMode, ParseMode, PermissiveMode} -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -45,17 +44,17 @@ private[sql] object JsonInferSchema { createParser: (JsonFactory, T) => JsonParser): StructType = { val parseMode = configOptions.parseMode val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord - val caseSensitive = SQLConf.get.caseSensitiveAnalysis - // perform schema inference on each row and merge afterwards - val rootType = json.mapPartitions { iter => + // In each RDD partition, perform schema inference on each row and merge afterwards. + val typeMerger = compatibleRootType(columnNameOfCorruptRecord, parseMode) + val mergedTypesFromPartitions = json.mapPartitions { iter => val factory = new JsonFactory() configOptions.setJacksonOptions(factory) iter.flatMap { row => try { Utils.tryWithResource(createParser(factory, row)) { parser => parser.nextToken() - Some(inferField(parser, configOptions, caseSensitive)) + Some(inferField(parser, configOptions)) } } catch { case e @ (_: RuntimeException | _: JsonProcessingException) => parseMode match { @@ -68,9 +67,13 @@ private[sql] object JsonInferSchema { s"Parse Mode: ${FailFastMode.name}.", e) } } - } - }.fold(StructType(Nil))( - compatibleRootType(columnNameOfCorruptRecord, parseMode, caseSensitive)) + }.reduceOption(typeMerger).toIterator + } + + // Here we get RDD local iterator then fold, instead of calling `RDD.fold` directly, because + // `RDD.fold` will run the fold function in DAGScheduler event loop thread, which may not have + // active SparkSession and `SQLConf.get` may point to the wrong configs. + val rootType = mergedTypesFromPartitions.toLocalIterator.fold(StructType(Nil))(typeMerger) canonicalizeType(rootType) match { case Some(st: StructType) => st @@ -100,15 +103,14 @@ private[sql] object JsonInferSchema { /** * Infer the type of a json document from the parser's token stream */ - private def inferField( - parser: JsonParser, configOptions: JSONOptions, caseSensitive: Boolean): DataType = { + private def inferField(parser: JsonParser, configOptions: JSONOptions): DataType = { import com.fasterxml.jackson.core.JsonToken._ parser.getCurrentToken match { case null | VALUE_NULL => NullType case FIELD_NAME => parser.nextToken() - inferField(parser, configOptions, caseSensitive) + inferField(parser, configOptions) case VALUE_STRING if parser.getTextLength < 1 => // Zero length strings and nulls have special handling to deal @@ -125,7 +127,7 @@ private[sql] object JsonInferSchema { while (nextUntil(parser, END_OBJECT)) { builder += StructField( parser.getCurrentName, - inferField(parser, configOptions, caseSensitive), + inferField(parser, configOptions), nullable = true) } val fields: Array[StructField] = builder.result() @@ -140,7 +142,7 @@ private[sql] object JsonInferSchema { var elementType: DataType = NullType while (nextUntil(parser, END_ARRAY)) { elementType = compatibleType( - elementType, inferField(parser, configOptions, caseSensitive), caseSensitive) + elementType, inferField(parser, configOptions)) } ArrayType(elementType) @@ -246,14 +248,13 @@ private[sql] object JsonInferSchema { */ private def compatibleRootType( columnNameOfCorruptRecords: String, - parseMode: ParseMode, - caseSensitive: Boolean): (DataType, DataType) => DataType = { + parseMode: ParseMode): (DataType, DataType) => DataType = { // Since we support array of json objects at the top level, // we need to check the element type and find the root level data type. case (ArrayType(ty1, _), ty2) => - compatibleRootType(columnNameOfCorruptRecords, parseMode, caseSensitive)(ty1, ty2) + compatibleRootType(columnNameOfCorruptRecords, parseMode)(ty1, ty2) case (ty1, ArrayType(ty2, _)) => - compatibleRootType(columnNameOfCorruptRecords, parseMode, caseSensitive)(ty1, ty2) + compatibleRootType(columnNameOfCorruptRecords, parseMode)(ty1, ty2) // Discard null/empty documents case (struct: StructType, NullType) => struct case (NullType, struct: StructType) => struct @@ -263,7 +264,7 @@ private[sql] object JsonInferSchema { withCorruptField(struct, o, columnNameOfCorruptRecords, parseMode) // If we get anything else, we call compatibleType. // Usually, when we reach here, ty1 and ty2 are two StructTypes. - case (ty1, ty2) => compatibleType(ty1, ty2, caseSensitive) + case (ty1, ty2) => compatibleType(ty1, ty2) } private[this] val emptyStructFieldArray = Array.empty[StructField] @@ -271,8 +272,8 @@ private[sql] object JsonInferSchema { /** * Returns the most general data type for two given data types. */ - def compatibleType(t1: DataType, t2: DataType, caseSensitive: Boolean): DataType = { - TypeCoercion.findTightestCommonType(t1, t2, caseSensitive).getOrElse { + def compatibleType(t1: DataType, t2: DataType): DataType = { + TypeCoercion.findTightestCommonType(t1, t2).getOrElse { // t1 or t2 is a StructType, ArrayType, or an unexpected type. (t1, t2) match { // Double support larger range than fixed decimal, DecimalType.Maximum should be enough @@ -307,8 +308,7 @@ private[sql] object JsonInferSchema { val f2Name = fields2(f2Idx).name val comp = f1Name.compareTo(f2Name) if (comp == 0) { - val dataType = compatibleType( - fields1(f1Idx).dataType, fields2(f2Idx).dataType, caseSensitive) + val dataType = compatibleType(fields1(f1Idx).dataType, fields2(f2Idx).dataType) newFields.add(StructField(f1Name, dataType, nullable = true)) f1Idx += 1 f2Idx += 1 @@ -331,17 +331,15 @@ private[sql] object JsonInferSchema { StructType(newFields.toArray(emptyStructFieldArray)) case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) => - ArrayType( - compatibleType(elementType1, elementType2, caseSensitive), - containsNull1 || containsNull2) + ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2) // The case that given `DecimalType` is capable of given `IntegralType` is handled in // `findTightestCommonTypeOfTwo`. Both cases below will be executed only when // the given `DecimalType` is not capable of the given `IntegralType`. case (t1: IntegralType, t2: DecimalType) => - compatibleType(DecimalType.forType(t1), t2, caseSensitive) + compatibleType(DecimalType.forType(t1), t2) case (t1: DecimalType, t2: IntegralType) => - compatibleType(t1, DecimalType.forType(t2), caseSensitive) + compatibleType(t1, DecimalType.forType(t2)) // strings and every string is a Json object. case (_, _) => StringType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 0dea767840ed3..cab00251622b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -61,7 +61,7 @@ class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] { case _: ClassNotFoundException => u case e: Exception => // the provider is valid, but failed to create a logical plan - u.failAnalysis(e.getMessage) + u.failAnalysis(e.getMessage, e) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala index 1a6b32429313a..8d6fb3820d420 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -29,12 +29,12 @@ class DataSourceRDDPartition[T : ClassTag](val index: Int, val inputPartition: I class DataSourceRDD[T: ClassTag]( sc: SparkContext, - @transient private val readerFactories: Seq[InputPartition[T]]) + @transient private val inputPartitions: Seq[InputPartition[T]]) extends RDD[T](sc, Nil) { override protected def getPartitions: Array[Partition] = { - readerFactories.zipWithIndex.map { - case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory) + inputPartitions.zipWithIndex.map { + case (inputPartition, index) => new DataSourceRDDPartition(index, inputPartition) }.toArray } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala index 9293d4f831bff..e894f8afd6762 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala @@ -23,17 +23,10 @@ import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project import org.apache.spark.sql.catalyst.rules.Rule object PushDownOperatorsToDataSource extends Rule[LogicalPlan] { - override def apply( - plan: LogicalPlan): LogicalPlan = plan transformUp { + override def apply(plan: LogicalPlan): LogicalPlan = plan match { // PhysicalOperation guarantees that filters are deterministic; no need to check - case PhysicalOperation(project, newFilters, relation : DataSourceV2Relation) => - // merge the filters - val filters = relation.filters match { - case Some(existing) => - existing ++ newFilters - case _ => - newFilters - } + case PhysicalOperation(project, filters, relation: DataSourceV2Relation) => + assert(relation.filters.isEmpty, "data source v2 should do push down only once.") val projectAttrs = project.map(_.toAttribute) val projectSet = AttributeSet(project.flatMap(_.references)) @@ -67,5 +60,7 @@ object PushDownOperatorsToDataSource extends Rule[LogicalPlan] { } else { filtered } + + case other => other.mapChildren(apply) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index daea6c39624d6..9e0ec9481b0de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -69,7 +69,7 @@ case class BroadcastExchangeExec( Future { // This will run in another thread. Set the execution id so that we can connect these jobs // with the correct execution. - SQLExecution.withExecutionId(sparkContext, executionId) { + SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) { try { val beforeCollect = System.nanoTime() // Use executeCollect/executeCollectIterator to avoid conversion to Scala types diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 6fa716d9fadee..0da0e8610c392 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, SparkPlan} @@ -183,7 +184,7 @@ case class BroadcastHashJoinExec( val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") val javaType = CodeGenerator.javaType(a.dataType) - val code = s""" + val code = code""" |boolean $isNull = true; |$javaType $value = ${CodeGenerator.defaultValue(a.dataType)}; |if ($matched != null) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index d8261f0f33b61..f4b9d132122e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ @@ -521,7 +522,7 @@ case class SortMergeJoinExec( if (a.nullable) { val isNull = ctx.freshName("isNull") val code = - s""" + code""" |$isNull = $leftRow.isNullAt($i); |$value = $isNull ? $defaultValue : ($valueCode); """.stripMargin @@ -533,7 +534,7 @@ case class SortMergeJoinExec( (ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, a.dataType)), leftVarsDecl) } else { - val code = s"$value = $valueCode;" + val code = code"$value = $valueCode;" val leftVarsDecl = s"""$javaType $value = $defaultValue;""" (ExprCode(code, FalseLiteral, JavaCode.variable(value, a.dataType)), leftVarsDecl) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index 80769d728b8f1..8e82cccbc8fa3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -97,6 +97,18 @@ case class FlatMapGroupsWithStateExec( override def keyExpressions: Seq[Attribute] = groupingAttributes + override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + timeoutConf match { + case ProcessingTimeTimeout => + true // Always run batches to process timeouts + case EventTimeTimeout => + // Process another non-data batch only if the watermark has changed in this executed plan + eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get + case _ => + false + } + } + override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver @@ -126,7 +138,6 @@ case class FlatMapGroupsWithStateExec( case _ => iter } - // Generate a iterator that returns the rows grouped by the grouping function // Note that this code ensures that the filtering for timeout occurs only after // all the data has been processed. This is to ensure that the timeout information of all @@ -194,11 +205,11 @@ case class FlatMapGroupsWithStateExec( throw new IllegalStateException( s"Cannot filter timed out keys for $timeoutConf") } - val timingOutKeys = store.getRange(None, None).filter { rowPair => + val timingOutPairs = store.getRange(None, None).filter { rowPair => val timeoutTimestamp = getTimeoutTimestamp(rowPair.value) timeoutTimestamp != NO_TIMESTAMP && timeoutTimestamp < timeoutThreshold } - timingOutKeys.flatMap { rowPair => + timingOutPairs.flatMap { rowPair => callFunctionAndUpdateState(rowPair.key, Iterator.empty, rowPair.value, hasTimedOut = true) } } else Iterator.empty diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java index 80aa5505db991..43ad4b3384ec3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java @@ -19,8 +19,8 @@ /** * This is an internal, deprecated interface. New source implementations should use the - * org.apache.spark.sql.sources.v2.reader.Offset class, which is the one that will be supported - * in the long term. + * org.apache.spark.sql.sources.v2.reader.streaming.Offset class, which is the one that will be + * supported in the long term. * * This class will be removed in a future release. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index fa7c8ee906ecd..afa664eb76525 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -187,6 +187,17 @@ case class StreamingSymmetricHashJoinExec( s"${getClass.getSimpleName} should not take $x as the JoinType") } + override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = { + val watermarkUsedForStateCleanup = + stateWatermarkPredicates.left.nonEmpty || stateWatermarkPredicates.right.nonEmpty + + // Latest watermark value is more than that used in this previous executed plan + val watermarkHasChanged = + eventTimeWatermark.isDefined && newMetadata.batchWatermarkMs > eventTimeWatermark.get + + watermarkUsedForStateCleanup && watermarkHasChanged + } + protected override def doExecute(): RDD[InternalRow] = { val stateStoreCoord = sqlContext.sessionState.streamingQueryManager.stateStoreCoordinator val stateStoreNames = SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide) @@ -319,8 +330,7 @@ case class StreamingSymmetricHashJoinExec( // outer join) if possible. In all cases, nothing needs to be outputted, hence the removal // needs to be done greedily by immediately consuming the returned iterator. val cleanupIter = joinType match { - case Inner => - leftSideJoiner.removeOldState() ++ rightSideJoiner.removeOldState() + case Inner => leftSideJoiner.removeOldState() ++ rightSideJoiner.removeOldState() case LeftOuter => rightSideJoiner.removeOldState() case RightOuter => leftSideJoiner.removeOldState() case _ => throwBadJoinTypeException() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index f58146ac42398..0e7d1019b9c8f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -122,16 +122,7 @@ class ContinuousExecution( s"Batch $latestEpochId was committed without end epoch offsets!") } committedOffsets = nextOffsets.toStreamProgress(sources) - - // Get to an epoch ID that has definitely never been sent to a sink before. Since sink - // commit happens between offset log write and commit log write, this means an epoch ID - // which is not in the offset log. - val (latestOffsetEpoch, _) = offsetLog.getLatest().getOrElse { - throw new IllegalStateException( - s"Offset log had no latest element. This shouldn't be possible because nextOffsets is" + - s"an element.") - } - currentBatchId = latestOffsetEpoch + 1 + currentBatchId = latestEpochId + 1 logDebug(s"Resuming at epoch $currentBatchId with committed offsets $committedOffsets") nextOffsets diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala index d8645576c2052..f38577b6a9f16 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala @@ -46,8 +46,6 @@ class ContinuousQueuedDataReader( // Important sequencing - we must get our starting point before the provider threads start running private var currentOffset: PartitionOffset = ContinuousDataSourceRDD.getContinuousReader(reader).getOffset - private var currentEpoch: Long = - context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong /** * The record types in the read buffer. @@ -115,8 +113,7 @@ class ContinuousQueuedDataReader( currentEntry match { case EpochMarker => epochCoordEndpoint.send(ReportPartitionOffset( - context.partitionId(), currentEpoch, currentOffset)) - currentEpoch += 1 + context.partitionId(), EpochTracker.getCurrentEpoch.get, currentOffset)) null case ContinuousRow(row, offset) => currentOffset = offset @@ -184,7 +181,7 @@ class ContinuousQueuedDataReader( private val epochCoordEndpoint = EpochCoordinatorRef.get( context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get) - // Note that this is *not* the same as the currentEpoch in [[ContinuousDataQueuedReader]]! That + // Note that this is *not* the same as the currentEpoch in [[ContinuousWriteRDD]]! That // field represents the epoch wrt the data being processed. The currentEpoch here is just a // counter to ensure we send the appropriate number of markers if we fall behind the driver. private var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 8d25d9ccc43d3..516a563bdcc7a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -85,7 +85,7 @@ class RateStreamContinuousReader(options: DataSourceOptions) val start = partitionStartMap(i) // Have each partition advance by numPartitions each row, with starting points staggered // by their partition index. - RateStreamContinuousDataReaderFactory( + RateStreamContinuousInputPartition( start.value, start.runTimeMs, i, @@ -113,7 +113,7 @@ class RateStreamContinuousReader(options: DataSourceOptions) } -case class RateStreamContinuousDataReaderFactory( +case class RateStreamContinuousInputPartition( startValue: Long, startTimeMs: Long, partitionIndex: Int, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala index 91f1576581511..ef5f0da1e7cc2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala @@ -45,7 +45,8 @@ class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactor val epochCoordinator = EpochCoordinatorRef.get( context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get) - var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong + EpochTracker.initializeCurrentEpoch( + context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong) while (!context.isInterrupted() && !context.isCompleted()) { var dataWriter: DataWriter[InternalRow] = null @@ -54,19 +55,24 @@ class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactor try { val dataIterator = prev.compute(split, context) dataWriter = writeTask.createDataWriter( - context.partitionId(), context.attemptNumber(), currentEpoch) + context.partitionId(), + context.attemptNumber(), + EpochTracker.getCurrentEpoch.get) while (dataIterator.hasNext) { dataWriter.write(dataIterator.next()) } logInfo(s"Writer for partition ${context.partitionId()} " + - s"in epoch $currentEpoch is committing.") + s"in epoch ${EpochTracker.getCurrentEpoch.get} is committing.") val msg = dataWriter.commit() epochCoordinator.send( - CommitPartitionEpoch(context.partitionId(), currentEpoch, msg) + CommitPartitionEpoch( + context.partitionId(), + EpochTracker.getCurrentEpoch.get, + msg) ) logInfo(s"Writer for partition ${context.partitionId()} " + - s"in epoch $currentEpoch committed.") - currentEpoch += 1 + s"in epoch ${EpochTracker.getCurrentEpoch.get} committed.") + EpochTracker.incrementCurrentEpoch() } catch { case _: InterruptedException => // Continuous shutdown always involves an interrupt. Just finish the task. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index cc6808065c0cd..8877ebeb26735 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -137,30 +137,71 @@ private[continuous] class EpochCoordinator( private val partitionOffsets = mutable.Map[(Long, Int), PartitionOffset]() + private var lastCommittedEpoch = startEpoch - 1 + // Remembers epochs that have to wait for previous epochs to be committed first. + private val epochsWaitingToBeCommitted = mutable.HashSet.empty[Long] + private def resolveCommitsAtEpoch(epoch: Long) = { - val thisEpochCommits = - partitionCommits.collect { case ((e, _), msg) if e == epoch => msg } + val thisEpochCommits = findPartitionCommitsForEpoch(epoch) val nextEpochOffsets = partitionOffsets.collect { case ((e, _), o) if e == epoch => o } if (thisEpochCommits.size == numWriterPartitions && nextEpochOffsets.size == numReaderPartitions) { - logDebug(s"Epoch $epoch has received commits from all partitions. Committing globally.") - // Sequencing is important here. We must commit to the writer before recording the commit - // in the query, or we will end up dropping the commit if we restart in the middle. - writer.commit(epoch, thisEpochCommits.toArray) - query.commit(epoch) - - // Cleanup state from before this epoch, now that we know all partitions are forever past it. - for (k <- partitionCommits.keys.filter { case (e, _) => e < epoch }) { - partitionCommits.remove(k) - } - for (k <- partitionOffsets.keys.filter { case (e, _) => e < epoch }) { - partitionOffsets.remove(k) + + // Check that last committed epoch is the previous one for sequencing of committed epochs. + // If not, add the epoch being currently processed to epochs waiting to be committed, + // otherwise commit it. + if (lastCommittedEpoch != epoch - 1) { + logDebug(s"Epoch $epoch has received commits from all partitions " + + s"and is waiting for epoch ${epoch - 1} to be committed first.") + epochsWaitingToBeCommitted.add(epoch) + } else { + commitEpoch(epoch, thisEpochCommits) + lastCommittedEpoch = epoch + + // Commit subsequent epochs that are waiting to be committed. + var nextEpoch = lastCommittedEpoch + 1 + while (epochsWaitingToBeCommitted.contains(nextEpoch)) { + val nextEpochCommits = findPartitionCommitsForEpoch(nextEpoch) + commitEpoch(nextEpoch, nextEpochCommits) + + epochsWaitingToBeCommitted.remove(nextEpoch) + lastCommittedEpoch = nextEpoch + nextEpoch += 1 + } + + // Cleanup state from before last committed epoch, + // now that we know all partitions are forever past it. + for (k <- partitionCommits.keys.filter { case (e, _) => e < lastCommittedEpoch }) { + partitionCommits.remove(k) + } + for (k <- partitionOffsets.keys.filter { case (e, _) => e < lastCommittedEpoch }) { + partitionOffsets.remove(k) + } } } } + /** + * Collect per-partition commits for an epoch. + */ + private def findPartitionCommitsForEpoch(epoch: Long): Iterable[WriterCommitMessage] = { + partitionCommits.collect { case ((e, _), msg) if e == epoch => msg } + } + + /** + * Commit epoch to the offset log. + */ + private def commitEpoch(epoch: Long, messages: Iterable[WriterCommitMessage]): Unit = { + logDebug(s"Epoch $epoch has received commits from all partitions " + + s"and is ready to be committed. Committing epoch $epoch.") + // Sequencing is important here. We must commit to the writer before recording the commit + // in the query, or we will end up dropping the commit if we restart in the middle. + writer.commit(epoch, messages.toArray) + query.commit(epoch) + } + override def receive: PartialFunction[Any, Unit] = { // If we just drop these messages, we won't do any writes to the query. The lame duck tasks // won't shed errors or anything. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochTracker.scala new file mode 100644 index 0000000000000..bc0ae428d4521 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochTracker.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous + +import java.util.concurrent.atomic.AtomicLong + +/** + * Tracks the current continuous processing epoch within a task. Call + * EpochTracker.getCurrentEpoch to get the current epoch. + */ +object EpochTracker { + // The current epoch. Note that this is a shared reference; ContinuousWriteRDD.compute() will + // update the underlying AtomicLong as it finishes epochs. Other code should only read the value. + private val currentEpoch: ThreadLocal[AtomicLong] = new ThreadLocal[AtomicLong] { + override def initialValue() = new AtomicLong(-1) + } + + /** + * Get the current epoch for the current task, or None if the task has no current epoch. + */ + def getCurrentEpoch: Option[Long] = { + currentEpoch.get().get() match { + case n if n < 0 => None + case e => Some(e) + } + } + + /** + * Increment the current epoch for this task thread. Should be called by [[ContinuousWriteRDD]] + * between epochs. + */ + def incrementCurrentEpoch(): Unit = { + currentEpoch.get().incrementAndGet() + } + + /** + * Initialize the current epoch for this task thread. Should be called by [[ContinuousWriteRDD]] + * at the beginning of a task. + */ + def initializeCurrentEpoch(startEpoch: Long): Unit = { + currentEpoch.get().set(startEpoch) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala new file mode 100644 index 0000000000000..270b1a5c28dee --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import java.util.UUID + +import org.apache.spark.{Partition, SparkContext, SparkEnv, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.NextIterator + +case class ContinuousShuffleReadPartition(index: Int, queueSize: Int) extends Partition { + // Initialized only on the executor, and only once even as we call compute() multiple times. + lazy val (reader: ContinuousShuffleReader, endpoint) = { + val env = SparkEnv.get.rpcEnv + val receiver = new UnsafeRowReceiver(queueSize, env) + val endpoint = env.setupEndpoint(s"UnsafeRowReceiver-${UUID.randomUUID()}", receiver) + TaskContext.get().addTaskCompletionListener { ctx => + env.stop(endpoint) + } + (receiver, endpoint) + } +} + +/** + * RDD at the map side of each continuous processing shuffle task. Upstream tasks send their + * shuffle output to the wrapped receivers in partitions of this RDD; each of the RDD's tasks + * poll from their receiver until an epoch marker is sent. + */ +class ContinuousShuffleReadRDD( + sc: SparkContext, + numPartitions: Int, + queueSize: Int = 1024) + extends RDD[UnsafeRow](sc, Nil) { + + override protected def getPartitions: Array[Partition] = { + (0 until numPartitions).map { partIndex => + ContinuousShuffleReadPartition(partIndex, queueSize) + }.toArray + } + + override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { + split.asInstanceOf[ContinuousShuffleReadPartition].reader.read() + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala new file mode 100644 index 0000000000000..42631c90ebc55 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow + +/** + * Trait for reading from a continuous processing shuffle. + */ +trait ContinuousShuffleReader { + /** + * Returns an iterator over the incoming rows in an epoch. Implementations should block waiting + * for new rows to arrive, and end the iterator once they've received epoch markers from all + * shuffle writers. + */ + def read(): Iterator[UnsafeRow] +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala new file mode 100644 index 0000000000000..b8adbb743c6c2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue} +import java.util.concurrent.atomic.AtomicBoolean + +import org.apache.spark.internal.Logging +import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.util.NextIterator + +/** + * Messages for the UnsafeRowReceiver endpoint. Either an incoming row or an epoch marker. + */ +private[shuffle] sealed trait UnsafeRowReceiverMessage extends Serializable +private[shuffle] case class ReceiverRow(row: UnsafeRow) extends UnsafeRowReceiverMessage +private[shuffle] case class ReceiverEpochMarker() extends UnsafeRowReceiverMessage + +/** + * RPC endpoint for receiving rows into a continuous processing shuffle task. Continuous shuffle + * writers will send rows here, with continuous shuffle readers polling for new rows as needed. + * + * TODO: Support multiple source tasks. We need to output a single epoch marker once all + * source tasks have sent one. + */ +private[shuffle] class UnsafeRowReceiver( + queueSize: Int, + override val rpcEnv: RpcEnv) + extends ThreadSafeRpcEndpoint with ContinuousShuffleReader with Logging { + // Note that this queue will be drained from the main task thread and populated in the RPC + // response thread. + private val queue = new ArrayBlockingQueue[UnsafeRowReceiverMessage](queueSize) + + // Exposed for testing to determine if the endpoint gets stopped on task end. + private[shuffle] val stopped = new AtomicBoolean(false) + + override def onStop(): Unit = { + stopped.set(true) + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case r: UnsafeRowReceiverMessage => + queue.put(r) + context.reply(()) + } + + override def read(): Iterator[UnsafeRow] = { + new NextIterator[UnsafeRow] { + override def getNext(): UnsafeRow = queue.take() match { + case ReceiverRow(r) => r + case ReceiverEpochMarker() => + finished = true + null + } + + override def close(): Unit = {} + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index daa2963220aef..b137f98045c5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -156,7 +156,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) logDebug(generateDebugString(newBlocks.flatten, startOrdinal, endOrdinal)) newBlocks.map { block => - new MemoryStreamDataReaderFactory(block).asInstanceOf[InputPartition[UnsafeRow]] + new MemoryStreamInputPartition(block).asInstanceOf[InputPartition[UnsafeRow]] }.asJava } } @@ -201,7 +201,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) } -class MemoryStreamDataReaderFactory(records: Array[UnsafeRow]) +class MemoryStreamInputPartition(records: Array[UnsafeRow]) extends InputPartition[UnsafeRow] { override def createPartitionReader(): InputPartitionReader[UnsafeRow] = { new InputPartitionReader[UnsafeRow] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala index fef792eab69d5..d1c3498450096 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala @@ -44,13 +44,12 @@ import org.apache.spark.util.RpcUtils * * ContinuousMemoryStream maintains a list of records for each partition. addData() will * distribute records evenly-ish across partitions. * * RecordEndpoint is set up as an endpoint for executor-side - * ContinuousMemoryStreamDataReader instances to poll. It returns the record at the specified - * offset within the list, or null if that offset doesn't yet have a record. + * ContinuousMemoryStreamInputPartitionReader instances to poll. It returns the record at + * the specified offset within the list, or null if that offset doesn't yet have a record. */ -class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) +class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPartitions: Int = 2) extends MemoryStreamBase[A](sqlContext) with ContinuousReader with ContinuousReadSupport { private implicit val formats = Serialization.formats(NoTypeHints) - private val NUM_PARTITIONS = 2 protected val logicalPlan = StreamingRelationV2(this, "memory", Map(), attributes, None)(sqlContext.sparkSession) @@ -58,7 +57,7 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) // ContinuousReader implementation @GuardedBy("this") - private val records = Seq.fill(NUM_PARTITIONS)(new ListBuffer[A]) + private val records = Seq.fill(numPartitions)(new ListBuffer[A]) @GuardedBy("this") private var startOffset: ContinuousMemoryStreamOffset = _ @@ -69,17 +68,17 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) def addData(data: TraversableOnce[A]): Offset = synchronized { // Distribute data evenly among partition lists. data.toSeq.zipWithIndex.map { - case (item, index) => records(index % NUM_PARTITIONS) += item + case (item, index) => records(index % numPartitions) += item } // The new target offset is the offset where all records in all partitions have been processed. - ContinuousMemoryStreamOffset((0 until NUM_PARTITIONS).map(i => (i, records(i).size)).toMap) + ContinuousMemoryStreamOffset((0 until numPartitions).map(i => (i, records(i).size)).toMap) } override def setStartOffset(start: Optional[Offset]): Unit = synchronized { // Inferred initial offset is position 0 in each partition. startOffset = start.orElse { - ContinuousMemoryStreamOffset((0 until NUM_PARTITIONS).map(i => (i, 0)).toMap) + ContinuousMemoryStreamOffset((0 until numPartitions).map(i => (i, 0)).toMap) }.asInstanceOf[ContinuousMemoryStreamOffset] } @@ -107,7 +106,7 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) startOffset.partitionNums.map { case (part, index) => - new ContinuousMemoryStreamDataReaderFactory( + new ContinuousMemoryStreamInputPartition( endpointName, part, index): InputPartition[Row] }.toList.asJava } @@ -152,12 +151,15 @@ object ContinuousMemoryStream { def apply[A : Encoder](implicit sqlContext: SQLContext): ContinuousMemoryStream[A] = new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext) + + def singlePartition[A : Encoder](implicit sqlContext: SQLContext): ContinuousMemoryStream[A] = + new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext, 1) } /** - * Data reader factory for continuous memory stream. + * An input partition for continuous memory stream. */ -class ContinuousMemoryStreamDataReaderFactory( +class ContinuousMemoryStreamInputPartition( driverEndpointName: String, partition: Int, startOffset: Int) extends InputPartition[Row] { @@ -166,7 +168,7 @@ class ContinuousMemoryStreamDataReaderFactory( } /** - * Data reader for continuous memory stream. + * An input partition reader for continuous memory stream. * * Polls the driver endpoint for new records. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala index 723cc3ad5bb89..fbff8db987110 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala @@ -167,7 +167,7 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: } (0 until numPartitions).map { p => - new RateStreamMicroBatchDataReaderFactory( + new RateStreamMicroBatchInputPartition( p, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) : InputPartition[Row] }.toList.asJava @@ -182,7 +182,7 @@ class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: s"numPartitions=${options.get(NUM_PARTITIONS).orElse("default")}" } -class RateStreamMicroBatchDataReaderFactory( +class RateStreamMicroBatchInputPartition( partitionId: Int, numPartitions: Int, rangeStart: Long, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index 01d8e75980993..3f11b8f79943c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -23,6 +23,7 @@ import scala.reflect.ClassTag import org.apache.spark.{Partition, TaskContext} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.execution.streaming.continuous.EpochTracker import org.apache.spark.sql.internal.SessionState import org.apache.spark.sql.types.StructType import org.apache.spark.util.SerializableConfiguration @@ -71,8 +72,15 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( StateStoreId(checkpointLocation, operatorId, partition.index), queryRunId) + // If we're in continuous processing mode, we should get the store version for the current + // epoch rather than the one at planning time. + val currentVersion = EpochTracker.getCurrentEpoch match { + case None => storeVersion + case Some(value) => value + } + store = StateStore.get( - storeProviderId, keySchema, valueSchema, indexOrdinal, storeVersion, + storeProviderId, keySchema, valueSchema, indexOrdinal, currentVersion, storeConf, hadoopConfBroadcast.value.value) val inputIter = dataRDD.iterator(partition, ctxt) storeUpdateFunction(store, inputIter) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala index 582528777f90e..bf46bc4cf904d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala @@ -58,21 +58,21 @@ private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with L _content ++= new RunningExecutionTable( parent, s"Running Queries (${running.size})", currentTime, - running.sortBy(_.submissionTime).reverse).toNodeSeq + running.sortBy(_.submissionTime).reverse).toNodeSeq(request) } if (completed.nonEmpty) { _content ++= new CompletedExecutionTable( parent, s"Completed Queries (${completed.size})", currentTime, - completed.sortBy(_.submissionTime).reverse).toNodeSeq + completed.sortBy(_.submissionTime).reverse).toNodeSeq(request) } if (failed.nonEmpty) { _content ++= new FailedExecutionTable( parent, s"Failed Queries (${failed.size})", currentTime, - failed.sortBy(_.submissionTime).reverse).toNodeSeq + failed.sortBy(_.submissionTime).reverse).toNodeSeq(request) } _content } @@ -111,7 +111,7 @@ private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with L } - UIUtils.headerSparkPage("SQL", summary ++ content, parent, Some(5000)) + UIUtils.headerSparkPage(request, "SQL", summary ++ content, parent, Some(5000)) } } @@ -133,7 +133,10 @@ private[ui] abstract class ExecutionTable( protected def header: Seq[String] - protected def row(currentTime: Long, executionUIData: SQLExecutionUIData): Seq[Node] = { + protected def row( + request: HttpServletRequest, + currentTime: Long, + executionUIData: SQLExecutionUIData): Seq[Node] = { val submissionTime = executionUIData.submissionTime val duration = executionUIData.completionTime.map(_.getTime()).getOrElse(currentTime) - submissionTime @@ -141,7 +144,7 @@ private[ui] abstract class ExecutionTable( def jobLinks(status: JobExecutionStatus): Seq[Node] = { executionUIData.jobs.flatMap { case (jobId, jobStatus) => if (jobStatus == status) { - [{jobId.toString}] + [{jobId.toString}] } else { None } @@ -153,7 +156,7 @@ private[ui] abstract class ExecutionTable( {executionUIData.executionId.toString} - {descriptionCell(executionUIData)} + {descriptionCell(request, executionUIData)} {UIUtils.formatDate(submissionTime)} @@ -179,7 +182,9 @@ private[ui] abstract class ExecutionTable( } - private def descriptionCell(execution: SQLExecutionUIData): Seq[Node] = { + private def descriptionCell( + request: HttpServletRequest, + execution: SQLExecutionUIData): Seq[Node] = { val details = if (execution.details != null && execution.details.nonEmpty) { +details @@ -192,27 +197,28 @@ private[ui] abstract class ExecutionTable( } val desc = if (execution.description != null && execution.description.nonEmpty) { - {execution.description} + {execution.description} } else { - {execution.executionId} + {execution.executionId} }
    {desc} {details}
    } - def toNodeSeq: Seq[Node] = { + def toNodeSeq(request: HttpServletRequest): Seq[Node] = {

    {tableName}

    {UIUtils.listingTable[SQLExecutionUIData]( - header, row(currentTime, _), executionUIDatas, id = Some(tableId))} + header, row(request, currentTime, _), executionUIDatas, id = Some(tableId))}
    } - private def jobURL(jobId: Long): String = - "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(parent.basePath), jobId) + private def jobURL(request: HttpServletRequest, jobId: Long): String = + "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(request, parent.basePath), jobId) - private def executionURL(executionID: Long): String = - s"${UIUtils.prependBaseUri(parent.basePath)}/${parent.prefix}/execution?id=$executionID" + private def executionURL(request: HttpServletRequest, executionID: Long): String = + s"${UIUtils.prependBaseUri( + request, parent.basePath)}/${parent.prefix}/execution?id=$executionID" } private[ui] class RunningExecutionTable( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala index e0554f0c4d337..282f7b4bb5a58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala @@ -49,7 +49,7 @@ class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging
  • {label} {jobs.toSeq.sorted.map { jobId => - {jobId.toString}  + {jobId.toString}  }}
  • } else { @@ -77,27 +77,31 @@ class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging val graph = sqlStore.planGraph(executionId) summary ++ - planVisualization(metrics, graph) ++ + planVisualization(request, metrics, graph) ++ physicalPlanDescription(executionUIData.physicalPlanDescription) }.getOrElse {
    No information to display for query {executionId}
    } - UIUtils.headerSparkPage(s"Details for Query $executionId", content, parent, Some(5000)) + UIUtils.headerSparkPage( + request, s"Details for Query $executionId", content, parent, Some(5000)) } - private def planVisualizationResources: Seq[Node] = { + private def planVisualizationResources(request: HttpServletRequest): Seq[Node] = { // scalastyle:off - - - - - + + + + + // scalastyle:on } - private def planVisualization(metrics: Map[Long, String], graph: SparkPlanGraph): Seq[Node] = { + private def planVisualization( + request: HttpServletRequest, + metrics: Map[Long, String], + graph: SparkPlanGraph): Seq[Node] = { val metadata = graph.allNodes.flatMap { node => val nodeId = s"plan-meta-data-${node.id}"
    {node.desc}
    @@ -112,13 +116,13 @@ class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging
    {graph.allNodes.size.toString}
    {metadata} - {planVisualizationResources} + {planVisualizationResources(request)} } - private def jobURL(jobId: Long): String = - "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(parent.basePath), jobId) + private def jobURL(request: HttpServletRequest, jobId: Long): String = + "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(request, parent.basePath), jobId) private def physicalPlanDescription(physicalPlanDescription: String): Seq[Node] = {
    diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index ff53367426f0d..443ba2aa3757d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -811,177 +811,6 @@ object functions { */ def var_pop(columnName: String): Column = var_pop(Column(columnName)) - /** - * Aggregate function: returns the number of non-null pairs. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_count(y: Column, x: Column): Column = withAggregateFunction { - RegrCount(y.expr, x.expr) - } - - /** - * Aggregate function: returns the number of non-null pairs. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_count(y: String, x: String): Column = regr_count(Column(y), Column(x)) - - /** - * Aggregate function: returns SUM(x*x)-SUM(x)*SUM(x)/N. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_sxx(y: Column, x: Column): Column = withAggregateFunction { - RegrSXX(y.expr, x.expr) - } - - /** - * Aggregate function: returns SUM(x*x)-SUM(x)*SUM(x)/N. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_sxx(y: String, x: String): Column = regr_sxx(Column(y), Column(x)) - - /** - * Aggregate function: returns SUM(y*y)-SUM(y)*SUM(y)/N. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_syy(y: Column, x: Column): Column = withAggregateFunction { - RegrSYY(y.expr, x.expr) - } - - /** - * Aggregate function: returns SUM(y*y)-SUM(y)*SUM(y)/N. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_syy(y: String, x: String): Column = regr_syy(Column(y), Column(x)) - - /** - * Aggregate function: returns the average of y. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_avgy(y: Column, x: Column): Column = withAggregateFunction { - RegrAvgY(y.expr, x.expr) - } - - /** - * Aggregate function: returns the average of y. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_avgy(y: String, x: String): Column = regr_avgy(Column(y), Column(x)) - - /** - * Aggregate function: returns the average of x. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_avgx(y: Column, x: Column): Column = withAggregateFunction { - RegrAvgX(y.expr, x.expr) - } - - /** - * Aggregate function: returns the average of x. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_avgx(y: String, x: String): Column = regr_avgx(Column(y), Column(x)) - - /** - * Aggregate function: returns the covariance of y and x multiplied for the number of items in - * the dataset. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_sxy(y: Column, x: Column): Column = withAggregateFunction { - RegrSXY(y.expr, x.expr) - } - - /** - * Aggregate function: returns the covariance of y and x multiplied for the number of items in - * the dataset. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_sxy(y: String, x: String): Column = regr_sxy(Column(y), Column(x)) - - /** - * Aggregate function: returns the slope of the linear regression line. Any pair with a NULL is - * ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_slope(y: Column, x: Column): Column = withAggregateFunction { - RegrSlope(y.expr, x.expr) - } - - /** - * Aggregate function: returns the slope of the linear regression line. Any pair with a NULL is - * ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_slope(y: String, x: String): Column = regr_slope(Column(y), Column(x)) - - /** - * Aggregate function: returns the coefficient of determination (also called R-squared or - * goodness of fit) for the regression line. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_r2(y: Column, x: Column): Column = withAggregateFunction { - RegrR2(y.expr, x.expr) - } - - /** - * Aggregate function: returns the coefficient of determination (also called R-squared or - * goodness of fit) for the regression line. Any pair with a NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_r2(y: String, x: String): Column = regr_r2(Column(y), Column(x)) - - /** - * Aggregate function: returns the y-intercept of the linear regression line. Any pair with a - * NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_intercept(y: Column, x: Column): Column = withAggregateFunction { - RegrIntercept(y.expr, x.expr) - } - - /** - * Aggregate function: returns the y-intercept of the linear regression line. Any pair with a - * NULL is ignored. - * - * @group agg_funcs - * @since 2.4.0 - */ - def regr_intercept(y: String, x: String): Column = regr_intercept(Column(y), Column(x)) - - ////////////////////////////////////////////////////////////////////////////////////////////// // Window functions @@ -2903,7 +2732,12 @@ object functions { /** * Returns number of months between dates `date1` and `date2`. - * The result is rounded off to 8 digits. + * If `date1` is later than `date2`, then the result is positive. + * If `date1` and `date2` are on the same day of month, or both are the last day of month, + * time of day will be ignored. + * + * Otherwise, the difference is calculated based on 31 days per month, and rounded to + * 8 digits. * @group datetime_funcs * @since 1.5.0 */ @@ -3251,6 +3085,17 @@ object functions { ArrayContains(column.expr, Literal(value)) } + /** + * Returns `true` if `a1` and `a2` have at least one non-null element in common. If not and both + * the arrays are non-empty and any of them contains a `null`, it returns `null`. It returns + * `false` otherwise. + * @group collection_funcs + * @since 2.4.0 + */ + def arrays_overlap(a1: Column, a2: Column): Column = withExpr { + ArraysOverlap(a1.expr, a2.expr) + } + /** * Returns an array containing all the elements in `x` from index `start` (or starting from the * end if `start` is negative) with the specified `length`. @@ -3397,9 +3242,9 @@ object functions { from_json(e, schema.asInstanceOf[DataType], options) /** - * (Scala-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType` - * of `StructType`s with the specified schema. Returns `null`, in the case of an unparseable - * string. + * (Scala-specific) Parses a column containing a JSON string into a `MapType` with `StringType` + * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. + * Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string @@ -3429,9 +3274,9 @@ object functions { from_json(e, schema, options.asScala.toMap) /** - * (Java-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType` - * of `StructType`s with the specified schema. Returns `null`, in the case of an unparseable - * string. + * (Java-specific) Parses a column containing a JSON string into a `MapType` with `StringType` + * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. + * Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string @@ -3458,8 +3303,9 @@ object functions { from_json(e, schema, Map.empty[String, String]) /** - * Parses a column containing a JSON string into a `StructType` or `ArrayType` of `StructType`s - * with the specified schema. Returns `null`, in the case of an unparseable string. + * Parses a column containing a JSON string into a `MapType` with `StringType` as keys type, + * `StructType` or `ArrayType` of `StructType`s with the specified schema. + * Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string @@ -3471,9 +3317,9 @@ object functions { from_json(e, schema, Map.empty[String, String]) /** - * (Java-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType` - * of `StructType`s with the specified schema. Returns `null`, in the case of an unparseable - * string. + * (Java-specific) Parses a column containing a JSON string into a `MapType` with `StringType` + * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. + * Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string as a json string. In Spark 2.1, @@ -3488,9 +3334,9 @@ object functions { } /** - * (Scala-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType` - * of `StructType`s with the specified schema. Returns `null`, in the case of an unparseable - * string. + * (Scala-specific) Parses a column containing a JSON string into a `MapType` with `StringType` + * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema. + * Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. * @param schema the schema to use when parsing the json string as a json string, it could be a @@ -3612,6 +3458,26 @@ object functions { */ def flatten(e: Column): Column = withExpr { Flatten(e.expr) } + /** + * Creates an array containing the left argument repeated the number of times given by the + * right argument. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_repeat(left: Column, right: Column): Column = withExpr { + ArrayRepeat(left.expr, right.expr) + } + + /** + * Creates an array containing the left argument repeated the number of times given by the + * right argument. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_repeat(e: Column, count: Int): Column = array_repeat(e, lit(count)) + /** * Returns an unordered array containing the keys of the map. * @group collection_funcs @@ -3626,6 +3492,13 @@ object functions { */ def map_values(e: Column): Column = withExpr { MapValues(e.expr) } + /** + * Returns an unordered array of all entries in the given map. + * @group collection_funcs + * @since 2.4.0 + */ + def map_entries(e: Column): Column = withExpr { MapEntries(e.expr) } + ////////////////////////////////////////////////////////////////////////////////////////////// // Mask functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 7cefd03e43bc3..97da2b1325f58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -242,7 +242,9 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo (sink, trigger) match { case (v2Sink: StreamWriteSupport, trigger: ContinuousTrigger) => - UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode) + if (sparkSession.sessionState.conf.isUnsupportedOperationCheckEnabled) { + UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode) + } new StreamingQueryWrapper(new ContinuousExecution( sparkSession, userSpecifiedName.orNull, diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java index 714638e500c94..445cb29f5ee3a 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java @@ -107,7 +107,8 @@ public List> planInputPartitions() { } } - static class JavaAdvancedInputPartition implements InputPartition, InputPartitionReader { + static class JavaAdvancedInputPartition implements InputPartition, + InputPartitionReader { private int start; private int end; private StructType requiredSchema; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala index 949505e449fd7..276496be3d62c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala @@ -39,7 +39,9 @@ class ConfigBehaviorSuite extends QueryTest with SharedSQLContext { def computeChiSquareTest(): Double = { val n = 10000 // Trigger a sort - val data = spark.range(0, n, 1, 1).sort('id.desc) + // Range has range partitioning in its output now. To have a range shuffle, we + // need to run a repartition first. + val data = spark.range(0, n, 1, 1).repartition(10).sort('id.desc) .selectExpr("SPARK_PARTITION_ID() pid", "id").as[(Int, Long)].collect() // Compute histogram for the number of records per partition post sort diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 4337fb2290fbc..96c28961e5aaf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -687,72 +687,4 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } } } - - test("SPARK-23907: regression functions") { - val emptyTableData = Seq.empty[(Double, Double)].toDF("a", "b") - val correlatedData = Seq[(Double, Double)]((2, 3), (3, 4), (7.5, 8.2), (10.3, 12)) - .toDF("a", "b") - val correlatedDataWithNull = Seq[(java.lang.Double, java.lang.Double)]( - (2.0, 3.0), (3.0, null), (7.5, 8.2), (10.3, 12.0)).toDF("a", "b") - checkAnswer(testData2.groupBy().agg(regr_count("a", "b")), Seq(Row(6))) - checkAnswer(testData3.groupBy().agg(regr_count("a", "b")), Seq(Row(1))) - checkAnswer(emptyTableData.groupBy().agg(regr_count("a", "b")), Seq(Row(0))) - - checkAggregatesWithTol(testData2.groupBy().agg(regr_sxx("a", "b")), Row(1.5), absTol) - checkAggregatesWithTol(testData3.groupBy().agg(regr_sxx("a", "b")), Row(0.0), absTol) - checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_sxx("a", "b")), Row(null), absTol) - checkAggregatesWithTol(testData2.groupBy().agg(regr_syy("b", "a")), Row(1.5), absTol) - checkAggregatesWithTol(testData3.groupBy().agg(regr_syy("b", "a")), Row(0.0), absTol) - checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_syy("b", "a")), Row(null), absTol) - - checkAggregatesWithTol(testData2.groupBy().agg(regr_avgx("a", "b")), Row(1.5), absTol) - checkAggregatesWithTol(testData3.groupBy().agg(regr_avgx("a", "b")), Row(2.0), absTol) - checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_avgx("a", "b")), Row(null), absTol) - checkAggregatesWithTol(testData2.groupBy().agg(regr_avgy("b", "a")), Row(1.5), absTol) - checkAggregatesWithTol(testData3.groupBy().agg(regr_avgy("b", "a")), Row(2.0), absTol) - checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_avgy("b", "a")), Row(null), absTol) - - checkAggregatesWithTol(testData2.groupBy().agg(regr_sxy("a", "b")), Row(0.0), absTol) - checkAggregatesWithTol(testData3.groupBy().agg(regr_sxy("a", "b")), Row(0.0), absTol) - checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_sxy("a", "b")), Row(null), absTol) - - checkAggregatesWithTol(testData2.groupBy().agg(regr_slope("a", "b")), Row(0.0), absTol) - checkAggregatesWithTol(testData3.groupBy().agg(regr_slope("a", "b")), Row(null), absTol) - checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_slope("a", "b")), Row(null), absTol) - - checkAggregatesWithTol(testData2.groupBy().agg(regr_r2("a", "b")), Row(0.0), absTol) - checkAggregatesWithTol(testData3.groupBy().agg(regr_r2("a", "b")), Row(null), absTol) - checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_r2("a", "b")), Row(null), absTol) - - checkAggregatesWithTol(testData2.groupBy().agg(regr_intercept("a", "b")), Row(2.0), absTol) - checkAggregatesWithTol(testData3.groupBy().agg(regr_intercept("a", "b")), Row(null), absTol) - checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_intercept("a", "b")), - Row(null), absTol) - - - checkAggregatesWithTol(correlatedData.groupBy().agg( - regr_count("a", "b"), - regr_avgx("a", "b"), - regr_avgy("a", "b"), - regr_sxx("a", "b"), - regr_syy("a", "b"), - regr_sxy("a", "b"), - regr_slope("a", "b"), - regr_r2("a", "b"), - regr_intercept("a", "b")), - Row(4, 6.8, 5.7, 51.28, 45.38, 48.06, 0.937207488, 0.992556013, -0.67301092), - absTol) - checkAggregatesWithTol(correlatedDataWithNull.groupBy().agg( - regr_count("a", "b"), - regr_avgx("a", "b"), - regr_avgy("a", "b"), - regr_sxx("a", "b"), - regr_syy("a", "b"), - regr_sxy("a", "b"), - regr_slope("a", "b"), - regr_r2("a", "b"), - regr_intercept("a", "b")), - Row(3, 7.73333333, 6.6, 40.82666666, 35.66, 37.98, 0.93027433, 0.99079694, -0.59412149), - absTol) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 8aad03a2d0222..cdf36ef65ee28 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -512,6 +512,50 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("map_entries") { + val dummyFilter = (c: Column) => c.isNotNull || c.isNull + + // Primitive-type elements + val idf = Seq( + Map[Int, Int](1 -> 100, 2 -> 200, 3 -> 300), + Map[Int, Int](), + null + ).toDF("m") + val iExpected = Seq( + Row(Seq(Row(1, 100), Row(2, 200), Row(3, 300))), + Row(Seq.empty), + Row(null) + ) + + checkAnswer(idf.select(map_entries('m)), iExpected) + checkAnswer(idf.selectExpr("map_entries(m)"), iExpected) + checkAnswer(idf.filter(dummyFilter('m)).select(map_entries('m)), iExpected) + checkAnswer( + spark.range(1).selectExpr("map_entries(map(1, null, 2, null))"), + Seq(Row(Seq(Row(1, null), Row(2, null))))) + checkAnswer( + spark.range(1).filter(dummyFilter('id)).selectExpr("map_entries(map(1, null, 2, null))"), + Seq(Row(Seq(Row(1, null), Row(2, null))))) + + // Non-primitive-type elements + val sdf = Seq( + Map[String, String]("a" -> "f", "b" -> "o", "c" -> "o"), + Map[String, String]("a" -> null, "b" -> null), + Map[String, String](), + null + ).toDF("m") + val sExpected = Seq( + Row(Seq(Row("a", "f"), Row("b", "o"), Row("c", "o"))), + Row(Seq(Row("a", null), Row("b", null))), + Row(Seq.empty), + Row(null) + ) + + checkAnswer(sdf.select(map_entries('m)), sExpected) + checkAnswer(sdf.selectExpr("map_entries(m)"), sExpected) + checkAnswer(sdf.filter(dummyFilter('m)).select(map_entries('m)), sExpected) + } + test("array contains function") { val df = Seq( (Seq[Int](1, 2), "x"), @@ -549,6 +593,35 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("arrays_overlap function") { + val df = Seq( + (Seq[Option[Int]](Some(1), Some(2)), Seq[Option[Int]](Some(-1), Some(10))), + (Seq[Option[Int]](Some(1), Some(2)), Seq[Option[Int]](Some(-1), None)), + (Seq[Option[Int]](Some(3), Some(2)), Seq[Option[Int]](Some(1), Some(2))) + ).toDF("a", "b") + + val answer = Seq(Row(false), Row(null), Row(true)) + + checkAnswer(df.select(arrays_overlap(df("a"), df("b"))), answer) + checkAnswer(df.selectExpr("arrays_overlap(a, b)"), answer) + + checkAnswer( + Seq((Seq(1, 2, 3), Seq(2.0, 2.5))).toDF("a", "b").selectExpr("arrays_overlap(a, b)"), + Row(true)) + + intercept[AnalysisException] { + sql("select arrays_overlap(array(1, 2, 3), array('a', 'b', 'c'))") + } + + intercept[AnalysisException] { + sql("select arrays_overlap(null, null)") + } + + intercept[AnalysisException] { + sql("select arrays_overlap(map(1, 2), map(3, 4))") + } + } + test("slice function") { val df = Seq( Seq(1, 2, 3), @@ -790,6 +863,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("element_at(a, -1)"), Seq(Row("3"), Row(""), Row(null)) ) + + val e = intercept[AnalysisException] { + Seq(("a string element", 1)).toDF().selectExpr("element_at(_1, _2)") + } + assert(e.message.contains( + "argument 1 requires (array or map) type, however, '`_1`' is of string type")) } test("concat function - arrays") { @@ -950,6 +1029,82 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } } + test("array_repeat function") { + val dummyFilter = (c: Column) => c.isNull || c.isNotNull // to switch codeGen on + val strDF = Seq( + ("hi", 2), + (null, 2) + ).toDF("a", "b") + + val strDFTwiceResult = Seq( + Row(Seq("hi", "hi")), + Row(Seq(null, null)) + ) + + checkAnswer(strDF.select(array_repeat($"a", 2)), strDFTwiceResult) + checkAnswer(strDF.filter(dummyFilter($"a")).select(array_repeat($"a", 2)), strDFTwiceResult) + checkAnswer(strDF.select(array_repeat($"a", $"b")), strDFTwiceResult) + checkAnswer(strDF.filter(dummyFilter($"a")).select(array_repeat($"a", $"b")), strDFTwiceResult) + checkAnswer(strDF.selectExpr("array_repeat(a, 2)"), strDFTwiceResult) + checkAnswer(strDF.selectExpr("array_repeat(a, b)"), strDFTwiceResult) + + val intDF = { + val schema = StructType(Seq( + StructField("a", IntegerType), + StructField("b", IntegerType))) + val data = Seq( + Row(3, 2), + Row(null, 2) + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + } + + val intDFTwiceResult = Seq( + Row(Seq(3, 3)), + Row(Seq(null, null)) + ) + + checkAnswer(intDF.select(array_repeat($"a", 2)), intDFTwiceResult) + checkAnswer(intDF.filter(dummyFilter($"a")).select(array_repeat($"a", 2)), intDFTwiceResult) + checkAnswer(intDF.select(array_repeat($"a", $"b")), intDFTwiceResult) + checkAnswer(intDF.filter(dummyFilter($"a")).select(array_repeat($"a", $"b")), intDFTwiceResult) + checkAnswer(intDF.selectExpr("array_repeat(a, 2)"), intDFTwiceResult) + checkAnswer(intDF.selectExpr("array_repeat(a, b)"), intDFTwiceResult) + + val nullCountDF = { + val schema = StructType(Seq( + StructField("a", StringType), + StructField("b", IntegerType))) + val data = Seq( + Row("hi", null), + Row(null, null) + ) + spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + } + + checkAnswer( + nullCountDF.select(array_repeat($"a", $"b")), + Seq( + Row(null), + Row(null) + ) + ) + + // Error test cases + val invalidTypeDF = Seq(("hi", "1")).toDF("a", "b") + + intercept[AnalysisException] { + invalidTypeDF.select(array_repeat($"a", $"b")) + } + intercept[AnalysisException] { + invalidTypeDF.select(array_repeat($"a", lit("1"))) + } + intercept[AnalysisException] { + invalidTypeDF.selectExpr("array_repeat(a, 1.0)") + } + + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 60e84e6ee7504..1cc8cb3874c9b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2265,4 +2265,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val df = spark.range(1).select($"id", new Column(Uuid())) checkAnswer(df, df.collect()) } + + test("SPARK-24313: access map with binary keys") { + val mapWithBinaryKey = map(lit(Array[Byte](1.toByte)), lit(1)) + checkAnswer(spark.range(1).select(mapWithBinaryKey.getItem(Array[Byte](1.toByte))), Row(1)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index e0f4d2ba685e1..d477d78dc14e3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1425,6 +1425,14 @@ class DatasetSuite extends QueryTest with SharedSQLContext { } } + test("SPARK-23627: provide isEmpty in DataSet") { + val ds1 = spark.emptyDataset[Int] + val ds2 = Seq(1, 2, 3).toDS() + + assert(ds1.isEmpty == true) + assert(ds2.isEmpty == false) + } + test("SPARK-22472: add null check for top-level primitive values") { // If the primitive values are from Option, we need to do runtime null check. val ds = Seq(Some(1), None).toDS().as[Int] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala index 109fcf90a3ec9..8280a3ce39845 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Expression, Generator} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StructType} @@ -315,6 +316,7 @@ case class EmptyGenerator() extends Generator { override def eval(input: InternalRow): TraversableOnce[InternalRow] = Seq.empty override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val iteratorClass = classOf[Iterator[_]].getName - ev.copy(code = s"$iteratorClass ${ev.value} = $iteratorClass$$.MODULE$$.empty();") + ev.copy(code = + code"$iteratorClass ${ev.value} = $iteratorClass$$.MODULE$$.empty();") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 00d2acc4a1d8a..055e1fc5640f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -326,4 +326,70 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { assert(errMsg4.getMessage.startsWith( "A type of keys and values in map() must be string, but got")) } + + test("SPARK-24027: from_json - map") { + val in = Seq("""{"a": 1, "b": 2, "c": 3}""").toDS() + val schema = + """ + |{ + | "type" : "map", + | "keyType" : "string", + | "valueType" : "integer", + | "valueContainsNull" : true + |} + """.stripMargin + val out = in.select(from_json($"value", schema, Map[String, String]())) + + assert(out.columns.head == "entries") + checkAnswer(out, Row(Map("a" -> 1, "b" -> 2, "c" -> 3))) + } + + test("SPARK-24027: from_json - map") { + val in = Seq("""{"a": {"b": 1}}""").toDS() + val schema = MapType(StringType, new StructType().add("b", IntegerType), true) + val out = in.select(from_json($"value", schema)) + + checkAnswer(out, Row(Map("a" -> Row(1)))) + } + + test("SPARK-24027: from_json - map>") { + val in = Seq("""{"a": {"b": 1}}""").toDS() + val schema = MapType(StringType, MapType(StringType, IntegerType)) + val out = in.select(from_json($"value", schema)) + + checkAnswer(out, Row(Map("a" -> Map("b" -> 1)))) + } + + test("SPARK-24027: roundtrip - from_json -> to_json - map") { + val json = """{"a":1,"b":2,"c":3}""" + val schema = MapType(StringType, IntegerType, true) + val out = Seq(json).toDS().select(to_json(from_json($"value", schema))) + + checkAnswer(out, Row(json)) + } + + test("SPARK-24027: roundtrip - to_json -> from_json - map") { + val in = Seq(Map("a" -> 1)).toDF() + val schema = MapType(StringType, IntegerType, true) + val out = in.select(from_json(to_json($"value"), schema)) + + checkAnswer(out, in) + } + + test("SPARK-24027: from_json - wrong map") { + val in = Seq("""{"a" 1}""").toDS() + val schema = MapType(StringType, IntegerType) + val out = in.select(from_json($"value", schema, Map[String, String]())) + + checkAnswer(out, Row(null)) + } + + test("SPARK-24027: from_json of a map with unsupported key type") { + val schema = MapType(StructType(StructField("f", IntegerType) :: Nil), StringType) + + checkAnswer(Seq("""{{"f": 1}: "a"}""").toDS().select(from_json($"value", schema)), + Row(null)) + checkAnswer(Seq("""{"{"f": 1}": "a"}""").toDS().select(from_json($"value", schema)), + Row(null)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index f0dfe6b76f7ae..b2aba8e72c5db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.{execution, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftOuter, RightOuter} -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition, Sort} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range, Repartition, Sort} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec} @@ -197,6 +197,18 @@ class PlannerSuite extends SharedSQLContext { assert(planned.cachedPlan.isInstanceOf[CollectLimitExec]) } + test("TakeOrderedAndProjectExec appears only when number of limit is below the threshold.") { + withSQLConf(SQLConf.TOP_K_SORT_FALLBACK_THRESHOLD.key -> "1000") { + val query0 = testData.select('value).orderBy('key).limit(100) + val planned0 = query0.queryExecution.executedPlan + assert(planned0.find(_.isInstanceOf[TakeOrderedAndProjectExec]).isDefined) + + val query1 = testData.select('value).orderBy('key).limit(2000) + val planned1 = query1.queryExecution.executedPlan + assert(planned1.find(_.isInstanceOf[TakeOrderedAndProjectExec]).isEmpty) + } + } + test("SPARK-23375: Cached sorted data doesn't need to be re-sorted") { val query = testData.select('key, 'value).sort('key.desc).cache() assert(query.queryExecution.optimizedPlan.isInstanceOf[InMemoryRelation]) @@ -621,6 +633,31 @@ class PlannerSuite extends SharedSQLContext { requiredOrdering = Seq(orderingA, orderingB), shouldHaveSort = true) } + + test("SPARK-24242: RangeExec should have correct output ordering and partitioning") { + val df = spark.range(10) + val rangeExec = df.queryExecution.executedPlan.collect { + case r: RangeExec => r + } + val range = df.queryExecution.optimizedPlan.collect { + case r: Range => r + } + assert(rangeExec.head.outputOrdering == range.head.outputOrdering) + assert(rangeExec.head.outputPartitioning == + RangePartitioning(rangeExec.head.outputOrdering, df.rdd.getNumPartitions)) + + val rangeInOnePartition = spark.range(1, 10, 1, 1) + val rangeExecInOnePartition = rangeInOnePartition.queryExecution.executedPlan.collect { + case r: RangeExec => r + } + assert(rangeExecInOnePartition.head.outputPartitioning == SinglePartition) + + val rangeInZeroPartition = spark.range(-10, -9, -20, 1) + val rangeExecInZeroPartition = rangeInZeroPartition.queryExecution.executedPlan.collect { + case r: RangeExec => r + } + assert(rangeExecInZeroPartition.head.outputPartitioning == UnknownPartitioning(0)) + } } // Used for unit-testing EnsureRequirements diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 9180a22c260f1..b714dcd5269fc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -51,12 +51,12 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { } test("Aggregate with grouping keys should be included in WholeStageCodegen") { - val df = spark.range(3).groupBy("id").count().orderBy("id") + val df = spark.range(3).groupBy(col("id") * 2).count().orderBy(col("id") * 2) val plan = df.queryExecution.executedPlan assert(plan.find(p => p.isInstanceOf[WholeStageCodegenExec] && p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined) - assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1))) + assert(df.collect() === Array(Row(0, 1), Row(2, 1), Row(4, 1))) } test("BroadcastHashJoin should be included in WholeStageCodegen") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReaderSuite.scala new file mode 100644 index 0000000000000..a39a25be262a6 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReaderSuite.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import java.io.File +import java.nio.charset.StandardCharsets +import java.nio.file.Files + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.test.SharedSQLContext + +class HadoopFileLinesReaderSuite extends SharedSQLContext { + def getLines( + path: File, + text: String, + ranges: Seq[(Long, Long)], + delimiter: Option[String] = None, + conf: Option[Configuration] = None): Seq[String] = { + val delimOpt = delimiter.map(_.getBytes(StandardCharsets.UTF_8)) + Files.write(path.toPath, text.getBytes(StandardCharsets.UTF_8)) + + val lines = ranges.map { case (start, length) => + val file = PartitionedFile(InternalRow.empty, path.getCanonicalPath, start, length) + val hadoopConf = conf.getOrElse(spark.sparkContext.hadoopConfiguration) + val reader = new HadoopFileLinesReader(file, delimOpt, hadoopConf) + + reader.map(_.toString) + }.flatten + + lines + } + + test("A split ends at the delimiter") { + withTempPath { path => + val lines = getLines(path, text = "a\r\nb", ranges = Seq((0, 1), (1, 3))) + assert(lines == Seq("a", "b")) + } + } + + test("A split cuts the delimiter") { + withTempPath { path => + val lines = getLines(path, text = "a\r\nb", ranges = Seq((0, 2), (2, 2))) + assert(lines == Seq("a", "b")) + } + } + + test("A split ends at the end of the delimiter") { + withTempPath { path => + val lines = getLines(path, text = "a\r\nb", ranges = Seq((0, 3), (3, 1))) + assert(lines == Seq("a", "b")) + } + } + + test("A split covers two lines") { + withTempPath { path => + val lines = getLines(path, text = "a\r\nb", ranges = Seq((0, 4), (4, 1))) + assert(lines == Seq("a", "b")) + } + } + + test("A split ends at the custom delimiter") { + withTempPath { path => + val lines = getLines(path, text = "a^_^b", ranges = Seq((0, 1), (1, 4)), Some("^_^")) + assert(lines == Seq("a", "b")) + } + } + + test("A split slices the custom delimiter") { + withTempPath { path => + val lines = getLines(path, text = "a^_^b", ranges = Seq((0, 2), (2, 3)), Some("^_^")) + assert(lines == Seq("a", "b")) + } + } + + test("The first split covers the first line and the custom delimiter") { + withTempPath { path => + val lines = getLines(path, text = "a^_^b", ranges = Seq((0, 4), (4, 1)), Some("^_^")) + assert(lines == Seq("a", "b")) + } + } + + test("A split cuts the first line") { + withTempPath { path => + val lines = getLines(path, text = "abc,def", ranges = Seq((0, 1)), Some(",")) + assert(lines == Seq("abc")) + } + } + + test("The split cuts both lines") { + withTempPath { path => + val lines = getLines(path, text = "abc,def", ranges = Seq((2, 2)), Some(",")) + assert(lines == Seq("def")) + } + } + + test("io.file.buffer.size is less than line length") { + val conf = spark.sparkContext.hadoopConfiguration + conf.set("io.file.buffer.size", "2") + withTempPath { path => + val lines = getLines(path, text = "abcdef\n123456", ranges = Seq((4, 4), (8, 5))) + assert(lines == Seq("123456")) + } + } + + test("line cannot be longer than line.maxlength") { + val conf = spark.sparkContext.hadoopConfiguration + conf.set("mapreduce.input.linerecordreader.line.maxlength", "5") + withTempPath { path => + val lines = getLines(path, text = "abcdef\n1234", ranges = Seq((0, 15))) + assert(lines == Seq("1234")) + } + } + + test("default delimiter is 0xd or 0xa or 0xd0xa") { + withTempPath { path => + val lines = getLines(path, text = "1\r2\n3\r\n4", ranges = Seq((0, 3), (3, 5))) + assert(lines == Seq("1", "2", "3", "4")) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala index 4b3ca8e60cab6..a1da3ec43eae3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommandSuite.scala @@ -23,9 +23,6 @@ import org.apache.spark.sql.test.SharedSQLContext class SaveIntoDataSourceCommandSuite extends SharedSQLContext { - override protected def sparkConf: SparkConf = super.sparkConf - .set("spark.redaction.regex", "(?i)password|url") - test("simpleString is redacted") { val URL = "connection.url" val PASS = "123" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala new file mode 100644 index 0000000000000..ec788df00aa92 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVBenchmarks.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.csv + +import java.io.File + +import org.apache.spark.SparkConf +import org.apache.spark.sql.{Column, Row, SparkSession} +import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.types._ +import org.apache.spark.util.{Benchmark, Utils} + +/** + * Benchmark to measure CSV read/write performance. + * To run this: + * spark-submit --class --jars + */ +object CSVBenchmarks { + val conf = new SparkConf() + + val spark = SparkSession.builder + .master("local[1]") + .appName("benchmark-csv-datasource") + .config(conf) + .getOrCreate() + import spark.implicits._ + + def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + def quotedValuesBenchmark(rowsNum: Int, numIters: Int): Unit = { + val benchmark = new Benchmark(s"Parsing quoted values", rowsNum) + + withTempPath { path => + val str = (0 until 10000).map(i => s""""$i"""").mkString(",") + + spark.range(rowsNum) + .map(_ => str) + .write.option("header", true) + .csv(path.getAbsolutePath) + + val schema = new StructType().add("value", StringType) + val ds = spark.read.option("header", true).schema(schema).csv(path.getAbsolutePath) + + benchmark.addCase(s"One quoted string", numIters) { _ => + ds.filter((_: Row) => true).count() + } + + /* + Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz + + Parsing quoted values: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + One quoted string 30273 / 30549 0.0 605451.2 1.0X + */ + benchmark.run() + } + } + + def multiColumnsBenchmark(rowsNum: Int): Unit = { + val colsNum = 1000 + val benchmark = new Benchmark(s"Wide rows with $colsNum columns", rowsNum) + + withTempPath { path => + val fields = Seq.tabulate(colsNum)(i => StructField(s"col$i", IntegerType)) + val schema = StructType(fields) + val values = (0 until colsNum).map(i => i.toString).mkString(",") + val columnNames = schema.fieldNames + + spark.range(rowsNum) + .select(Seq.tabulate(colsNum)(i => lit(i).as(s"col$i")): _*) + .write.option("header", true) + .csv(path.getAbsolutePath) + + val ds = spark.read.schema(schema).csv(path.getAbsolutePath) + + benchmark.addCase(s"Select $colsNum columns", 3) { _ => + ds.select("*").filter((row: Row) => true).count() + } + val cols100 = columnNames.take(100).map(Column(_)) + benchmark.addCase(s"Select 100 columns", 3) { _ => + ds.select(cols100: _*).filter((row: Row) => true).count() + } + benchmark.addCase(s"Select one column", 3) { _ => + ds.select($"col1").filter((row: Row) => true).count() + } + + /* + Intel(R) Core(TM) i7-7920HQ CPU @ 3.10GHz + + Wide rows with 1000 columns: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + -------------------------------------------------------------------------------------------- + Select 1000 columns 76910 / 78065 0.0 76909.8 1.0X + Select 100 columns 28625 / 32884 0.0 28625.1 2.7X + Select one column 22498 / 22669 0.0 22497.8 3.4X + */ + benchmark.run() + } + } + + def main(args: Array[String]): Unit = { + quotedValuesBenchmark(rowsNum = 50 * 1000, numIters = 3) + multiColumnsBenchmark(rowsNum = 1000 * 1000) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 461abdd96d3f3..5f9f799a6c466 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -260,14 +260,16 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te } test("test for DROPMALFORMED parsing mode") { - Seq(false, true).foreach { multiLine => - val cars = spark.read - .format("csv") - .option("multiLine", multiLine) - .options(Map("header" -> "true", "mode" -> "dropmalformed")) - .load(testFile(carsFile)) + withSQLConf(SQLConf.CSV_PARSER_COLUMN_PRUNING.key -> "false") { + Seq(false, true).foreach { multiLine => + val cars = spark.read + .format("csv") + .option("multiLine", multiLine) + .options(Map("header" -> "true", "mode" -> "dropmalformed")) + .load(testFile(carsFile)) - assert(cars.select("year").collect().size === 2) + assert(cars.select("year").collect().size === 2) + } } } @@ -1322,4 +1324,77 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te val sampled = spark.read.option("inferSchema", true).option("samplingRatio", 1.0).csv(ds) assert(sampled.count() == ds.count()) } + + test("SPARK-17916: An empty string should not be coerced to null when nullValue is passed.") { + val litNull: String = null + val df = Seq( + (1, "John Doe"), + (2, ""), + (3, "-"), + (4, litNull) + ).toDF("id", "name") + + // Checks for new behavior where an empty string is not coerced to null when `nullValue` is + // set to anything but an empty string literal. + withTempPath { path => + df.write + .option("nullValue", "-") + .csv(path.getAbsolutePath) + val computed = spark.read + .option("nullValue", "-") + .schema(df.schema) + .csv(path.getAbsolutePath) + val expected = Seq( + (1, "John Doe"), + (2, ""), + (3, litNull), + (4, litNull) + ).toDF("id", "name") + + checkAnswer(computed, expected) + } + // Keeps the old behavior where empty string us coerced to nullValue is not passed. + withTempPath { path => + df.write + .csv(path.getAbsolutePath) + val computed = spark.read + .schema(df.schema) + .csv(path.getAbsolutePath) + val expected = Seq( + (1, "John Doe"), + (2, litNull), + (3, "-"), + (4, litNull) + ).toDF("id", "name") + + checkAnswer(computed, expected) + } + } + + test("SPARK-24244: Select a subset of all columns") { + withTempPath { path => + import collection.JavaConverters._ + val schema = new StructType() + .add("f1", IntegerType).add("f2", IntegerType).add("f3", IntegerType) + .add("f4", IntegerType).add("f5", IntegerType).add("f6", IntegerType) + .add("f7", IntegerType).add("f8", IntegerType).add("f9", IntegerType) + .add("f10", IntegerType).add("f11", IntegerType).add("f12", IntegerType) + .add("f13", IntegerType).add("f14", IntegerType).add("f15", IntegerType) + + val odf = spark.createDataFrame(List( + Row(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), + Row(-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15) + ).asJava, schema) + odf.write.csv(path.getCanonicalPath) + val idf = spark.read + .schema(schema) + .csv(path.getCanonicalPath) + .select('f15, 'f10, 'f5) + + checkAnswer( + idf, + List(Row(15, 10, 5), Row(-15, -10, -5)) + ) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 34d23ee53220d..4b3921c61a000 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -122,10 +122,10 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("Get compatible type") { def checkDataType(t1: DataType, t2: DataType, expected: DataType) { - var actual = compatibleType(t1, t2, conf.caseSensitiveAnalysis) + var actual = compatibleType(t1, t2) assert(actual == expected, s"Expected $expected as the most general data type for $t1 and $t2, found $actual") - actual = compatibleType(t2, t1, conf.caseSensitiveAnalysis) + actual = compatibleType(t2, t1) assert(actual == expected, s"Expected $expected as the most general data type for $t1 and $t2, found $actual") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 4d0ecdef60986..90da7eb8c4fb5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -650,13 +650,15 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } test("SPARK-23852: Broken Parquet push-down for partially-written stats") { - // parquet-1217.parquet contains a single column with values -1, 0, 1, 2 and null. - // The row-group statistics include null counts, but not min and max values, which - // triggers PARQUET-1217. - val df = readResourceParquetFile("test-data/parquet-1217.parquet") + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { + // parquet-1217.parquet contains a single column with values -1, 0, 1, 2 and null. + // The row-group statistics include null counts, but not min and max values, which + // triggers PARQUET-1217. + val df = readResourceParquetFile("test-data/parquet-1217.parquet") - // Will return 0 rows if PARQUET-1217 is not fixed. - assert(df.where("col > 0").count() === 2) + // Will return 0 rows if PARQUET-1217 is not fixed. + assert(df.where("col > 0").count() === 2) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index adcaf2d76519f..8251ff159e05f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.debug import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData.TestData @@ -33,14 +34,16 @@ class DebuggingSuite extends SparkFunSuite with SharedSQLContext { } test("debugCodegen") { - val res = codegenString(spark.range(10).groupBy("id").count().queryExecution.executedPlan) + val res = codegenString(spark.range(10).groupBy(col("id") * 2).count() + .queryExecution.executedPlan) assert(res.contains("Subtree 1 / 2")) assert(res.contains("Subtree 2 / 2")) assert(res.contains("Object[]")) } test("debugCodegenStringSeq") { - val res = codegenStringSeq(spark.range(10).groupBy("id").count().queryExecution.executedPlan) + val res = codegenStringSeq(spark.range(10).groupBy(col("id") * 2).count() + .queryExecution.executedPlan) assert(res.length == 2) assert(res.forall{ case (subtree, code) => subtree.contains("Range") && code.contains("Object[]")}) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala index 39a010f970ce5..bf72e5c99689f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala @@ -309,7 +309,7 @@ class RateSourceSuite extends StreamTest { val data = scala.collection.mutable.ListBuffer[Row]() tasks.asScala.foreach { - case t: RateStreamContinuousDataReaderFactory => + case t: RateStreamContinuousInputPartition => val startTimeMs = reader.getStartOffset() .asInstanceOf[RateStreamOffset] .partitionToValueAndRunTimeMs(t.partitionIndex) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala new file mode 100644 index 0000000000000..3dd0712e02448 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.test.SQLTestUtils + +class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils { + import testImplicits._ + + protected var spark: SparkSession = null + + // Create a new [[SparkSession]] running in local-cluster mode. + override def beforeAll(): Unit = { + super.beforeAll() + spark = SparkSession.builder() + .master("local-cluster[2,1,1024]") + .appName("testing") + .getOrCreate() + } + + override def afterAll(): Unit = { + spark.stop() + spark = null + } + + test("ReadOnlySQLConf is correctly created at the executor side") { + SQLConf.get.setConfString("spark.sql.x", "a") + try { + val checks = spark.range(10).mapPartitions { it => + val conf = SQLConf.get + Iterator(conf.isInstanceOf[ReadOnlySQLConf] && conf.getConfString("spark.sql.x") == "a") + }.collect() + assert(checks.forall(_ == true)) + } finally { + SQLConf.get.unsetConf("spark.sql.x") + } + } + + test("case-sensitive config should work for json schema inference") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + withTempPath { path => + val pathString = path.getCanonicalPath + spark.range(10).select('id.as("ID")).write.json(pathString) + spark.range(10).write.mode("append").json(pathString) + assert(spark.read.json(pathString).columns.toSet == Set("id", "ID")) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 5238adce4a699..bc2aca65e803f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -1190,4 +1190,20 @@ class JDBCSuite extends SparkFunSuite assert(sql("select * from people_view").schema === schema) } } + + test("SPARK-23856 Spark jdbc setQueryTimeout option") { + val numJoins = 100 + val longRunningQuery = + s"SELECT t0.NAME AS c0, ${(1 to numJoins).map(i => s"t$i.NAME AS c$i").mkString(", ")} " + + s"FROM test.people t0 ${(1 to numJoins).map(i => s"join test.people t$i").mkString(" ")}" + val df = spark.read.format("jdbc") + .option("Url", urlWithUserAndPass) + .option("dbtable", s"($longRunningQuery)") + .option("queryTimeout", 1) + .load() + val errMsg = intercept[SparkException] { + df.collect() + }.getMessage + assert(errMsg.contains("Statement was canceled or the session timed out")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index 1985b1dc82879..1c2c92d1f0737 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -515,4 +515,22 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { }.getMessage assert(e.contains("NULL not allowed for column \"NAME\"")) } + + ignore("SPARK-23856 Spark jdbc setQueryTimeout option") { + // The behaviour of the option `queryTimeout` depends on how JDBC drivers implement the API + // `setQueryTimeout`. For example, in the h2 JDBC driver, `executeBatch` invokes multiple + // INSERT queries in a batch and `setQueryTimeout` means that the driver checks the timeout + // of each query. In the PostgreSQL JDBC driver, `setQueryTimeout` means that the driver + // checks the timeout of an entire batch in a driver side. So, the test below fails because + // this test suite depends on the h2 JDBC driver and the JDBC write path internally + // uses `executeBatch`. + val errMsg = intercept[SparkException] { + spark.range(10000000L).selectExpr("id AS k", "id AS v").coalesce(1).write + .mode(SaveMode.Overwrite) + .option("queryTimeout", 1) + .option("batchsize", Int.MaxValue) + .jdbc(url1, "TEST.TIMEOUTTEST", properties) + }.getMessage + assert(errMsg.contains("Statement was canceled or the session timed out")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index b1416bff87ee7..988c8e6753e25 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -615,20 +615,20 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest testStream(result, Update)( AddData(inputData, "a"), - CheckLastBatch(("a", "1")), + CheckNewAnswer(("a", "1")), assertNumStateRows(total = 1, updated = 1), AddData(inputData, "a", "b"), - CheckLastBatch(("a", "2"), ("b", "1")), + CheckNewAnswer(("a", "2"), ("b", "1")), assertNumStateRows(total = 2, updated = 2), StopStream, StartStream(), AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a - CheckLastBatch(("b", "2")), + CheckNewAnswer(("b", "2")), assertNumStateRows(total = 1, updated = 2), StopStream, StartStream(), AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and - CheckLastBatch(("a", "1"), ("c", "1")), + CheckNewAnswer(("a", "1"), ("c", "1")), assertNumStateRows(total = 3, updated = 2) ) } @@ -657,15 +657,15 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest .flatMapGroupsWithState(Update, GroupStateTimeout.NoTimeout)(stateFunc) testStream(result, Update)( AddData(inputData, "a", "a", "b"), - CheckLastBatch(("a", "1"), ("a", "2"), ("b", "1")), + CheckNewAnswer(("a", "1"), ("a", "2"), ("b", "1")), StopStream, StartStream(), AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a - CheckLastBatch(("b", "2")), + CheckNewAnswer(("b", "2")), StopStream, StartStream(), AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and - CheckLastBatch(("a", "1"), ("c", "1")) + CheckNewAnswer(("a", "1"), ("c", "1")) ) } @@ -694,22 +694,22 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest testStream(result, Complete)( AddData(inputData, "a"), - CheckLastBatch(("a", 1)), + CheckNewAnswer(("a", 1)), AddData(inputData, "a", "b"), // mapGroups generates ("a", "2"), ("b", "1"); so increases counts of a and b by 1 - CheckLastBatch(("a", 2), ("b", 1)), + CheckNewAnswer(("a", 2), ("b", 1)), StopStream, StartStream(), AddData(inputData, "a", "b"), // mapGroups should remove state for "a" and generate ("a", "-1"), ("b", "2") ; // so increment a and b by 1 - CheckLastBatch(("a", 3), ("b", 2)), + CheckNewAnswer(("a", 3), ("b", 2)), StopStream, StartStream(), AddData(inputData, "a", "c"), // mapGroups should recreate state for "a" and generate ("a", "1"), ("c", "1") ; // so increment a and c by 1 - CheckLastBatch(("a", 4), ("b", 2), ("c", 1)) + CheckNewAnswer(("a", 4), ("b", 2), ("c", 1)) ) } @@ -729,8 +729,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest } test("flatMapGroupsWithState - streaming with processing time timeout") { - // Function to maintain running count up to 2, and then remove the count - // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) + // Function to maintain the count as state and set the proc. time timeout delay of 10 seconds. + // It returns the count if changed, or -1 if the state was removed by timeout. val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } assertCannotGetWatermark { state.getCurrentWatermarkMs() } @@ -757,17 +757,17 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), AddData(inputData, "a"), AdvanceManualClock(1 * 1000), - CheckLastBatch(("a", "1")), + CheckNewAnswer(("a", "1")), assertNumStateRows(total = 1, updated = 1), AddData(inputData, "b"), AdvanceManualClock(1 * 1000), - CheckLastBatch(("b", "1")), + CheckNewAnswer(("b", "1")), assertNumStateRows(total = 2, updated = 1), AddData(inputData, "b"), AdvanceManualClock(10 * 1000), - CheckLastBatch(("a", "-1"), ("b", "2")), + CheckNewAnswer(("a", "-1"), ("b", "2")), assertNumStateRows(total = 1, updated = 2), StopStream, @@ -775,38 +775,42 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest AddData(inputData, "c"), AdvanceManualClock(11 * 1000), - CheckLastBatch(("b", "-1"), ("c", "1")), + CheckNewAnswer(("b", "-1"), ("c", "1")), assertNumStateRows(total = 1, updated = 2), - AddData(inputData, "c"), - AdvanceManualClock(20 * 1000), - CheckLastBatch(("c", "2")), - assertNumStateRows(total = 1, updated = 1) + AdvanceManualClock(12 * 1000), + AssertOnQuery { _ => clock.getTimeMillis() == 35000 }, + Execute { q => + failAfter(streamingTimeout) { + while (q.lastProgress.timestamp != "1970-01-01T00:00:35.000Z") { + Thread.sleep(1) + } + } + }, + CheckNewAnswer(("c", "-1")), + assertNumStateRows(total = 0, updated = 0) ) } test("flatMapGroupsWithState - streaming with event time timeout + watermark") { - // Function to maintain the max event time - // Returns the max event time in the state, or -1 if the state was removed by timeout + // Function to maintain the max event time as state and set the timeout timestamp based on the + // current max event time seen. It returns the max event time in the state, or -1 if the state + // was removed by timeout. val stateFunc = (key: String, values: Iterator[(String, Long)], state: GroupState[Long]) => { assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } assertCanGetWatermark { state.getCurrentWatermarkMs() >= -1 } - val timeoutDelay = 5 - if (key != "a") { - Iterator.empty + val timeoutDelaySec = 5 + if (state.hasTimedOut) { + state.remove() + Iterator((key, -1)) } else { - if (state.hasTimedOut) { - state.remove() - Iterator((key, -1)) - } else { - val valuesSeq = values.toSeq - val maxEventTime = math.max(valuesSeq.map(_._2).max, state.getOption.getOrElse(0L)) - val timeoutTimestampMs = maxEventTime + timeoutDelay - state.update(maxEventTime) - state.setTimeoutTimestamp(timeoutTimestampMs * 1000) - Iterator((key, maxEventTime.toInt)) - } + val valuesSeq = values.toSeq + val maxEventTimeSec = math.max(valuesSeq.map(_._2).max, state.getOption.getOrElse(0L)) + val timeoutTimestampSec = maxEventTimeSec + timeoutDelaySec + state.update(maxEventTimeSec) + state.setTimeoutTimestamp(timeoutTimestampSec * 1000) + Iterator((key, maxEventTimeSec.toInt)) } } val inputData = MemoryStream[(String, Int)] @@ -819,15 +823,23 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest .flatMapGroupsWithState(Update, EventTimeTimeout)(stateFunc) testStream(result, Update)( - StartStream(Trigger.ProcessingTime("1 second")), - AddData(inputData, ("a", 11), ("a", 13), ("a", 15)), // Set timeout timestamp of ... - CheckLastBatch(("a", 15)), // "a" to 15 + 5 = 20s, watermark to 5s + StartStream(), + + AddData(inputData, ("a", 11), ("a", 13), ("a", 15)), + // Max event time = 15. Timeout timestamp for "a" = 15 + 5 = 20. Watermark = 15 - 10 = 5. + CheckNewAnswer(("a", 15)), // Output = max event time of a + AddData(inputData, ("a", 4)), // Add data older than watermark for "a" - CheckLastBatch(), // No output as data should get filtered by watermark - AddData(inputData, ("dummy", 35)), // Set watermark = 35 - 10 = 25s - CheckLastBatch(), // No output as no data for "a" - AddData(inputData, ("a", 24)), // Add data older than watermark, should be ignored - CheckLastBatch(("a", -1)) // State for "a" should timeout and emit -1 + CheckNewAnswer(), // No output as data should get filtered by watermark + + AddData(inputData, ("a", 10)), // Add data newer than watermark for "a" + CheckNewAnswer(("a", 15)), // Max event time is still the same + // Timeout timestamp for "a" is still 20 as max event time for "a" is still 15. + // Watermark is still 5 as max event time for all data is still 15. + + AddData(inputData, ("b", 31)), // Add data newer than watermark for "b", not "a" + // Watermark = 31 - 10 = 21, so "a" should be timed out as timeout timestamp for "a" is 20. + CheckNewAnswer(("a", -1), ("b", 31)) // State for "a" should timeout and emit -1 ) } @@ -856,20 +868,20 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest testStream(result, Update)( AddData(inputData, "a"), - CheckLastBatch(("a", "1")), + CheckNewAnswer(("a", "1")), assertNumStateRows(total = 1, updated = 1), AddData(inputData, "a", "b"), - CheckLastBatch(("a", "2"), ("b", "1")), + CheckNewAnswer(("a", "2"), ("b", "1")), assertNumStateRows(total = 2, updated = 2), StopStream, StartStream(), AddData(inputData, "a", "b"), // should remove state for "a" and return count as -1 - CheckLastBatch(("a", "-1"), ("b", "2")), + CheckNewAnswer(("a", "-1"), ("b", "2")), assertNumStateRows(total = 1, updated = 2), StopStream, StartStream(), AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 - CheckLastBatch(("a", "1"), ("c", "1")), + CheckNewAnswer(("a", "1"), ("c", "1")), assertNumStateRows(total = 3, updated = 2) ) } @@ -920,15 +932,15 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest testStream(result, Update)( setFailInTask(false), AddData(inputData, "a"), - CheckLastBatch(("a", 1L)), + CheckNewAnswer(("a", 1L)), AddData(inputData, "a"), - CheckLastBatch(("a", 2L)), + CheckNewAnswer(("a", 2L)), setFailInTask(true), AddData(inputData, "a"), ExpectFailure[SparkException](), // task should fail but should not increment count setFailInTask(false), StartStream(), - CheckLastBatch(("a", 3L)) // task should not fail, and should show correct count + CheckNewAnswer(("a", 3L)) // task should not fail, and should show correct count ) } @@ -938,7 +950,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest val result = inputData.toDS.groupByKey(x => x).mapGroupsWithState(stateFunc) testStream(result, Update)( AddData(inputData, "a"), - CheckLastBatch("a"), + CheckNewAnswer("a"), AssertOnQuery(_.lastExecution.executedPlan.outputPartitioning === UnknownPartitioning(0)) ) } @@ -1000,7 +1012,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock), AddData(inputData, ("a", 1L)), AdvanceManualClock(1 * 1000), - CheckLastBatch(("a", "1")) + CheckNewAnswer(("a", "1")) ) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 9d139a927bea5..f348dac1319cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -199,15 +199,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be case class CheckAnswerRowsByFunc( globalCheckFunction: Seq[Row] => Unit, lastOnly: Boolean) extends StreamAction with StreamMustBeRunning { - override def toString: String = s"$operatorName" - private def operatorName = if (lastOnly) "CheckLastBatchByFunc" else "CheckAnswerByFunc" + override def toString: String = if (lastOnly) "CheckLastBatchByFunc" else "CheckAnswerByFunc" } case class CheckNewAnswerRows(expectedAnswer: Seq[Row]) extends StreamAction with StreamMustBeRunning { - override def toString: String = s"$operatorName: ${expectedAnswer.mkString(",")}" - - private def operatorName = "CheckNewAnswer" + override def toString: String = s"CheckNewAnswer: ${expectedAnswer.mkString(",")}" } object CheckNewAnswer { @@ -218,6 +215,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be val toExternalRow = RowEncoder(encoder.schema).resolveAndBind() CheckNewAnswerRows((data +: moreData).map(d => toExternalRow.fromRow(encoder.toRow(d)))) } + + def apply(rows: Row*): CheckNewAnswerRows = CheckNewAnswerRows(rows) } /** Stops the stream. It must currently be running. */ @@ -747,7 +746,6 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be error => failTest(error) } } - pos += 1 } try { @@ -761,8 +759,11 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be currentStream.asInstanceOf[MicroBatchExecution].withProgressLocked { actns.foreach(executeAction) } + pos += 1 - case action: StreamAction => executeAction(action) + case action: StreamAction => + executeAction(action) + pos += 1 } if (streamThreadDeathCause != null) { failTest("Stream Thread Died", streamThreadDeathCause) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index da8f9608c1e9c..1f62357e6d09e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -62,20 +62,20 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with AddData(input1, 1), CheckAnswer(), AddData(input2, 1, 10), // 1 arrived on input1 first, then input2, should join - CheckLastBatch((1, 2, 3)), + CheckNewAnswer((1, 2, 3)), AddData(input1, 10), // 10 arrived on input2 first, then input1, should join - CheckLastBatch((10, 20, 30)), + CheckNewAnswer((10, 20, 30)), AddData(input2, 1), // another 1 in input2 should join with 1 input1 - CheckLastBatch((1, 2, 3)), + CheckNewAnswer((1, 2, 3)), StopStream, StartStream(), AddData(input1, 1), // multiple 1s should be kept in state causing multiple (1, 2, 3) - CheckLastBatch((1, 2, 3), (1, 2, 3)), + CheckNewAnswer((1, 2, 3), (1, 2, 3)), StopStream, StartStream(), AddData(input1, 100), AddData(input2, 100), - CheckLastBatch((100, 200, 300)) + CheckNewAnswer((100, 200, 300)) ) } @@ -97,25 +97,25 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( AddData(input1, 1), - CheckLastBatch(), + CheckNewAnswer(), AddData(input2, 1), - CheckLastBatch((1, 10, 2, 3)), + CheckNewAnswer((1, 10, 2, 3)), StopStream, StartStream(), AddData(input1, 25), - CheckLastBatch(), + CheckNewAnswer(), StopStream, StartStream(), AddData(input2, 25), - CheckLastBatch((25, 30, 50, 75)), + CheckNewAnswer((25, 30, 50, 75)), AddData(input1, 1), - CheckLastBatch((1, 10, 2, 3)), // State for 1 still around as there is no watermark + CheckNewAnswer((1, 10, 2, 3)), // State for 1 still around as there is no watermark StopStream, StartStream(), AddData(input1, 5), - CheckLastBatch(), + CheckNewAnswer(), AddData(input2, 5), - CheckLastBatch((5, 10, 10, 15)) // No filter by any watermark + CheckNewAnswer((5, 10, 10, 15)) // No filter by any watermark ) } @@ -142,27 +142,27 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with assertNumStateRows(total = 1, updated = 1), AddData(input2, 1), - CheckLastBatch((1, 10, 2, 3)), + CheckAnswer((1, 10, 2, 3)), assertNumStateRows(total = 2, updated = 1), StopStream, StartStream(), AddData(input1, 25), - CheckLastBatch(), // since there is only 1 watermark operator, the watermark should be 15 - assertNumStateRows(total = 3, updated = 1), + CheckNewAnswer(), // watermark = 15, no-data-batch should remove 2 rows having window=[0,10] + assertNumStateRows(total = 1, updated = 1), AddData(input2, 25), - CheckLastBatch((25, 30, 50, 75)), // watermark = 15 should remove 2 rows having window=[0,10] + CheckNewAnswer((25, 30, 50, 75)), assertNumStateRows(total = 2, updated = 1), StopStream, StartStream(), AddData(input2, 1), - CheckLastBatch(), // Should not join as < 15 removed - assertNumStateRows(total = 2, updated = 0), // row not add as 1 < state key watermark = 15 + CheckNewAnswer(), // Should not join as < 15 removed + assertNumStateRows(total = 2, updated = 0), // row not add as 1 < state key watermark = 15 AddData(input1, 5), - CheckLastBatch(), // Should not join or add to state as < 15 got filtered by watermark + CheckNewAnswer(), // Same reason as above assertNumStateRows(total = 2, updated = 0) ) } @@ -189,42 +189,39 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with AddData(leftInput, (1, 5)), CheckAnswer(), AddData(rightInput, (1, 11)), - CheckLastBatch((1, 5, 11)), + CheckNewAnswer((1, 5, 11)), AddData(rightInput, (1, 10)), - CheckLastBatch(), // no match as neither 5, nor 10 from leftTime is less than rightTime 10 - 5 + CheckNewAnswer(), // no match as leftTime 5 is not < rightTime 10 - 5 assertNumStateRows(total = 3, updated = 3), // Increase event time watermark to 20s by adding data with time = 30s on both inputs AddData(leftInput, (1, 3), (1, 30)), - CheckLastBatch((1, 3, 10), (1, 3, 11)), + CheckNewAnswer((1, 3, 10), (1, 3, 11)), assertNumStateRows(total = 5, updated = 2), AddData(rightInput, (0, 30)), - CheckLastBatch(), - assertNumStateRows(total = 6, updated = 1), + CheckNewAnswer(), // event time watermark: max event time - 10 ==> 30 - 10 = 20 + // so left side going to only receive data where leftTime > 20 // right side state constraint: 20 < leftTime < rightTime - 5 ==> rightTime > 25 - - // Run another batch with event time = 25 to clear right state where rightTime <= 25 - AddData(rightInput, (0, 30)), - CheckLastBatch(), - assertNumStateRows(total = 5, updated = 1), // removed (1, 11) and (1, 10), added (0, 30) + // right state where rightTime <= 25 will be cleared, (1, 11) and (1, 10) removed + assertNumStateRows(total = 4, updated = 1), // New data to right input should match with left side (1, 3) and (1, 5), as left state should // not be cleared. But rows rightTime <= 20 should be filtered due to event time watermark and // state rows with rightTime <= 25 should be removed from state. // (1, 20) ==> filtered by event time watermark = 20 // (1, 21) ==> passed filter, matched with left (1, 3) and (1, 5), not added to state - // as state watermark = 25 + // as 21 < state watermark = 25 // (1, 28) ==> passed filter, matched with left (1, 3) and (1, 5), added to state AddData(rightInput, (1, 20), (1, 21), (1, 28)), - CheckLastBatch((1, 3, 21), (1, 5, 21), (1, 3, 28), (1, 5, 28)), - assertNumStateRows(total = 6, updated = 1), + CheckNewAnswer((1, 3, 21), (1, 5, 21), (1, 3, 28), (1, 5, 28)), + assertNumStateRows(total = 5, updated = 1), // New data to left input with leftTime <= 20 should be filtered due to event time watermark AddData(leftInput, (1, 20), (1, 21)), - CheckLastBatch((1, 21, 28)), - assertNumStateRows(total = 7, updated = 1) + CheckNewAnswer((1, 21, 28)), + assertNumStateRows(total = 6, updated = 1) ) } @@ -275,38 +272,39 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with AddData(leftInput, (1, 20)), CheckAnswer(), AddData(rightInput, (1, 14), (1, 15), (1, 25), (1, 26), (1, 30), (1, 31)), - CheckLastBatch((1, 20, 15), (1, 20, 25), (1, 20, 26), (1, 20, 30)), + CheckNewAnswer((1, 20, 15), (1, 20, 25), (1, 20, 26), (1, 20, 30)), assertNumStateRows(total = 7, updated = 7), // If rightTime = 60, then it matches only leftTime = [50, 65] AddData(rightInput, (1, 60)), - CheckLastBatch(), // matches with nothing on the left + CheckNewAnswer(), // matches with nothing on the left AddData(leftInput, (1, 49), (1, 50), (1, 65), (1, 66)), - CheckLastBatch((1, 50, 60), (1, 65, 60)), - assertNumStateRows(total = 12, updated = 5), + CheckNewAnswer((1, 50, 60), (1, 65, 60)), // Event time watermark = min(left: 66 - delay 20 = 46, right: 60 - delay 30 = 30) = 30 // Left state value watermark = 30 - 10 = slightly less than 20 (since condition has <=) // Should drop < 20 from left, i.e., none // Right state value watermark = 30 - 5 = slightly less than 25 (since condition has <=) // Should drop < 25 from the right, i.e., 14 and 15 - AddData(leftInput, (1, 30), (1, 31)), // 30 should not be processed or added to stat - CheckLastBatch((1, 31, 26), (1, 31, 30), (1, 31, 31)), - assertNumStateRows(total = 11, updated = 1), // 12 - 2 removed + 1 added + assertNumStateRows(total = 10, updated = 5), // 12 - 2 removed + + AddData(leftInput, (1, 30), (1, 31)), // 30 should not be processed or added to state + CheckNewAnswer((1, 31, 26), (1, 31, 30), (1, 31, 31)), + assertNumStateRows(total = 11, updated = 1), // only 31 added // Advance the watermark AddData(rightInput, (1, 80)), - CheckLastBatch(), - assertNumStateRows(total = 12, updated = 1), - + CheckNewAnswer(), // Event time watermark = min(left: 66 - delay 20 = 46, right: 80 - delay 30 = 50) = 46 // Left state value watermark = 46 - 10 = slightly less than 36 (since condition has <=) // Should drop < 36 from left, i.e., 20, 31 (30 was not added) // Right state value watermark = 46 - 5 = slightly less than 41 (since condition has <=) // Should drop < 41 from the right, i.e., 25, 26, 30, 31 - AddData(rightInput, (1, 50)), - CheckLastBatch((1, 49, 50), (1, 50, 50)), - assertNumStateRows(total = 7, updated = 1) // 12 - 6 removed + 1 added + assertNumStateRows(total = 6, updated = 1), // 12 - 6 removed + + AddData(rightInput, (1, 46), (1, 50)), // 46 should not be processed or added to state + CheckNewAnswer((1, 49, 50), (1, 50, 50)), + assertNumStateRows(total = 7, updated = 1) // 50 added ) } @@ -322,7 +320,7 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with input1.addData(1) q.awaitTermination(10000) } - assert(e.toString.contains("Stream stream joins without equality predicate is not supported")) + assert(e.toString.contains("Stream-stream join without equality predicate is not supported")) } test("stream stream self join") { @@ -404,10 +402,11 @@ class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with AddData(input1, 1, 5), AddData(input2, 1, 5, 10), AddData(input3, 5, 10), - CheckLastBatch((5, 10, 5, 15, 5, 25))) + CheckNewAnswer((5, 10, 5, 15, 5, 25))) } } + class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with BeforeAndAfter { import testImplicits._ @@ -465,13 +464,13 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with MultiAddData(leftInput, 1, 2, 3)(rightInput, 3, 4, 5), // The left rows with leftValue <= 4 should generate their outer join row now and // not get added to the state. - CheckLastBatch(Row(3, 10, 6, "9"), Row(1, 10, 2, null), Row(2, 10, 4, null)), + CheckNewAnswer(Row(3, 10, 6, "9"), Row(1, 10, 2, null), Row(2, 10, 4, null)), assertNumStateRows(total = 4, updated = 4), // We shouldn't get more outer join rows when the watermark advances. MultiAddData(leftInput, 20)(rightInput, 21), - CheckLastBatch(), + CheckNewAnswer(), AddData(rightInput, 20), - CheckLastBatch((20, 30, 40, "60")) + CheckNewAnswer((20, 30, 40, "60")) ) } @@ -492,15 +491,15 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( MultiAddData(leftInput, 3, 4, 5)(rightInput, 1, 2, 3), - // The right rows with value <= 7 should never be added to the state. - CheckLastBatch(Row(3, 10, 6, "9")), + // The right rows with rightValue <= 7 should never be added to the state. + CheckNewAnswer(Row(3, 10, 6, "9")), // rightValue = 9 > 7 hence joined and added to state assertNumStateRows(total = 4, updated = 4), // When the watermark advances, we get the outer join rows just as we would if they // were added but didn't match the full join condition. - MultiAddData(leftInput, 20)(rightInput, 21), - CheckLastBatch(), + MultiAddData(leftInput, 20)(rightInput, 21), // watermark = 10, no-data-batch computes nulls + CheckNewAnswer(Row(4, 10, 8, null), Row(5, 10, 10, null)), AddData(rightInput, 20), - CheckLastBatch(Row(20, 30, 40, "60"), Row(4, 10, 8, null), Row(5, 10, 10, null)) + CheckNewAnswer(Row(20, 30, 40, "60")) ) } @@ -521,15 +520,15 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( MultiAddData(leftInput, 1, 2, 3)(rightInput, 3, 4, 5), - // The left rows with value <= 4 should never be added to the state. - CheckLastBatch(Row(3, 10, 6, "9")), + // The left rows with leftValue <= 4 should never be added to the state. + CheckNewAnswer(Row(3, 10, 6, "9")), // leftValue = 7 > 4 hence joined and added to state assertNumStateRows(total = 4, updated = 4), // When the watermark advances, we get the outer join rows just as we would if they // were added but didn't match the full join condition. - MultiAddData(leftInput, 20)(rightInput, 21), - CheckLastBatch(), + MultiAddData(leftInput, 20)(rightInput, 21), // watermark = 10, no-data-batch computes nulls + CheckNewAnswer(Row(4, 10, null, "12"), Row(5, 10, null, "15")), AddData(rightInput, 20), - CheckLastBatch(Row(20, 30, 40, "60"), Row(4, 10, null, "12"), Row(5, 10, null, "15")) + CheckNewAnswer(Row(20, 30, 40, "60")) ) } @@ -552,13 +551,13 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with MultiAddData(leftInput, 3, 4, 5)(rightInput, 1, 2, 3), // The right rows with rightValue <= 7 should generate their outer join row now and // not get added to the state. - CheckLastBatch(Row(3, 10, 6, "9"), Row(1, 10, null, "3"), Row(2, 10, null, "6")), + CheckNewAnswer(Row(3, 10, 6, "9"), Row(1, 10, null, "3"), Row(2, 10, null, "6")), assertNumStateRows(total = 4, updated = 4), // We shouldn't get more outer join rows when the watermark advances. MultiAddData(leftInput, 20)(rightInput, 21), - CheckLastBatch(), + CheckNewAnswer(), AddData(rightInput, 20), - CheckLastBatch((20, 30, 40, "60")) + CheckNewAnswer((20, 30, 40, "60")) ) } @@ -568,14 +567,14 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( // Test inner part of the join. MultiAddData(leftInput, 1, 2, 3, 4, 5)(rightInput, 3, 4, 5, 6, 7), - CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)), - // Old state doesn't get dropped until the batch *after* it gets introduced, so the - // nulls won't show up until the next batch after the watermark advances. - MultiAddData(leftInput, 21)(rightInput, 22), - CheckLastBatch(), - assertNumStateRows(total = 12, updated = 12), + CheckNewAnswer((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)), + + MultiAddData(leftInput, 21)(rightInput, 22), // watermark = 11, no-data-batch computes nulls + CheckNewAnswer(Row(1, 10, 2, null), Row(2, 10, 4, null)), + assertNumStateRows(total = 2, updated = 12), + AddData(leftInput, 22), - CheckLastBatch(Row(22, 30, 44, 66), Row(1, 10, 2, null), Row(2, 10, 4, null)), + CheckNewAnswer(Row(22, 30, 44, 66)), assertNumStateRows(total = 3, updated = 1) ) } @@ -586,14 +585,14 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( // Test inner part of the join. MultiAddData(leftInput, 1, 2, 3, 4, 5)(rightInput, 3, 4, 5, 6, 7), - CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)), - // Old state doesn't get dropped until the batch *after* it gets introduced, so the - // nulls won't show up until the next batch after the watermark advances. - MultiAddData(leftInput, 21)(rightInput, 22), - CheckLastBatch(), - assertNumStateRows(total = 12, updated = 12), + CheckNewAnswer((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)), + + MultiAddData(leftInput, 21)(rightInput, 22), // watermark = 11, no-data-batch computes nulls + CheckNewAnswer(Row(6, 10, null, 18), Row(7, 10, null, 21)), + assertNumStateRows(total = 2, updated = 12), + AddData(leftInput, 22), - CheckLastBatch(Row(22, 30, 44, 66), Row(6, 10, null, 18), Row(7, 10, null, 21)), + CheckNewAnswer(Row(22, 30, 44, 66)), assertNumStateRows(total = 3, updated = 1) ) } @@ -627,21 +626,18 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with AddData(leftInput, (1, 5), (3, 5)), CheckAnswer(), AddData(rightInput, (1, 10), (2, 5)), - CheckLastBatch((1, 1, 5, 10)), + CheckNewAnswer((1, 1, 5, 10)), AddData(rightInput, (1, 11)), - CheckLastBatch(), // no match as left time is too low + CheckNewAnswer(), // no match as left time is too low assertNumStateRows(total = 5, updated = 5), // Increase event time watermark to 20s by adding data with time = 30s on both inputs AddData(leftInput, (1, 7), (1, 30)), - CheckLastBatch((1, 1, 7, 10), (1, 1, 7, 11)), + CheckNewAnswer((1, 1, 7, 10), (1, 1, 7, 11)), assertNumStateRows(total = 7, updated = 2), - AddData(rightInput, (0, 30)), - CheckLastBatch(), - assertNumStateRows(total = 8, updated = 1), - AddData(rightInput, (0, 30)), - CheckLastBatch(outerResult), - assertNumStateRows(total = 3, updated = 1) + AddData(rightInput, (0, 30)), // watermark = 30 - 10 = 20, no-data-batch computes nulls + CheckNewAnswer(outerResult), + assertNumStateRows(total = 2, updated = 1) ) } } @@ -665,36 +661,41 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with testStream(joined)( // leftValue <= 10 should generate outer join rows even though it matches right keys MultiAddData(leftInput, 1, 2, 3)(rightInput, 1, 2, 3), - CheckLastBatch(Row(1, 10, 2, null), Row(2, 10, 4, null), Row(3, 10, 6, null)), - MultiAddData(leftInput, 20)(rightInput, 21), - CheckLastBatch(), - assertNumStateRows(total = 5, updated = 5), // 1...3 added, but 20 and 21 not added + CheckNewAnswer(Row(1, 10, 2, null), Row(2, 10, 4, null), Row(3, 10, 6, null)), + assertNumStateRows(total = 3, updated = 3), // only right 1, 2, 3 added + + MultiAddData(leftInput, 20)(rightInput, 21), // watermark = 10, no-data-batch cleared < 10 + CheckNewAnswer(), + assertNumStateRows(total = 2, updated = 2), // only 20 and 21 left in state + AddData(rightInput, 20), - CheckLastBatch( - Row(20, 30, 40, 60)), + CheckNewAnswer(Row(20, 30, 40, 60)), assertNumStateRows(total = 3, updated = 1), + // leftValue and rightValue both satisfying condition should not generate outer join rows - MultiAddData(leftInput, 40, 41)(rightInput, 40, 41), - CheckLastBatch((40, 50, 80, 120), (41, 50, 82, 123)), - MultiAddData(leftInput, 70)(rightInput, 71), - CheckLastBatch(), - assertNumStateRows(total = 6, updated = 6), // all inputs added since last check + MultiAddData(leftInput, 40, 41)(rightInput, 40, 41), // watermark = 31 + CheckNewAnswer((40, 50, 80, 120), (41, 50, 82, 123)), + assertNumStateRows(total = 4, updated = 4), // only left 40, 41 + right 40,41 left in state + + MultiAddData(leftInput, 70)(rightInput, 71), // watermark = 60 + CheckNewAnswer(), + assertNumStateRows(total = 2, updated = 2), // only 70, 71 left in state + AddData(rightInput, 70), - CheckLastBatch((70, 80, 140, 210)), + CheckNewAnswer((70, 80, 140, 210)), assertNumStateRows(total = 3, updated = 1), + // rightValue between 300 and 1000 should generate outer join rows even though it matches left - MultiAddData(leftInput, 101, 102, 103)(rightInput, 101, 102, 103), - CheckLastBatch(), + MultiAddData(leftInput, 101, 102, 103)(rightInput, 101, 102, 103), // watermark = 91 + CheckNewAnswer(), + assertNumStateRows(total = 6, updated = 3), // only 101 - 103 left in state + MultiAddData(leftInput, 1000)(rightInput, 1001), - CheckLastBatch(), - assertNumStateRows(total = 8, updated = 5), // 101...103 added, but 1000 and 1001 not added - AddData(rightInput, 1000), - CheckLastBatch( - Row(1000, 1010, 2000, 3000), + CheckNewAnswer( Row(101, 110, 202, null), Row(102, 110, 204, null), Row(103, 110, 206, null)), - assertNumStateRows(total = 3, updated = 1) + assertNumStateRows(total = 2, updated = 2) ) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala new file mode 100644 index 0000000000000..b7ef637f5270e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming.continuous + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.streaming.OutputMode + +class ContinuousAggregationSuite extends ContinuousSuiteBase { + import testImplicits._ + + test("not enabled") { + val ex = intercept[AnalysisException] { + val input = ContinuousMemoryStream.singlePartition[Int] + testStream(input.toDF().agg(max('value)), OutputMode.Complete)() + } + + assert(ex.getMessage.contains("Continuous processing does not support Aggregate operations")) + } + + test("basic") { + withSQLConf(("spark.sql.streaming.unsupportedOperationCheck", "false")) { + val input = ContinuousMemoryStream.singlePartition[Int] + + testStream(input.toDF().agg(max('value)), OutputMode.Complete)( + AddData(input, 0, 1, 2), + CheckAnswer(2), + StopStream, + AddData(input, 3, 4, 5), + StartStream(), + CheckAnswer(5), + AddData(input, -1, -2, -3), + CheckAnswer(5)) + } + } + + test("repeated restart") { + withSQLConf(("spark.sql.streaming.unsupportedOperationCheck", "false")) { + val input = ContinuousMemoryStream.singlePartition[Int] + + testStream(input.toDF().agg(max('value)), OutputMode.Complete)( + AddData(input, 0, 1, 2), + CheckAnswer(2), + StopStream, + StartStream(), + StopStream, + StartStream(), + StopStream, + StartStream(), + AddData(input, 0), + CheckAnswer(2), + AddData(input, 5), + CheckAnswer(5)) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala index f47d3ec8ae025..e663fa8312da4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala @@ -51,6 +51,7 @@ class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar { startEpoch, spark, SparkEnv.get) + EpochTracker.initializeCurrentEpoch(0) } override def afterEach(): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala index 99e30561f81d5..82836dced9df7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala @@ -120,7 +120,7 @@ class EpochCoordinatorSuite verifyCommitsInOrderOf(List(1, 2)) } - ignore("consequent epochs, a message for epoch k arrives after messages for epoch (k + 1)") { + test("consequent epochs, a message for epoch k arrives after messages for epoch (k + 1)") { setWriterPartitions(2) setReaderPartitions(2) @@ -141,7 +141,7 @@ class EpochCoordinatorSuite verifyCommitsInOrderOf(List(1, 2)) } - ignore("several epochs, messages arrive in order 1 -> 3 -> 4 -> 2") { + test("several epochs, messages arrive in order 1 -> 3 -> 4 -> 2") { setWriterPartitions(1) setReaderPartitions(1) @@ -162,7 +162,7 @@ class EpochCoordinatorSuite verifyCommitsInOrderOf(List(1, 2, 3, 4)) } - ignore("several epochs, messages arrive in order 1 -> 3 -> 5 -> 4 -> 2") { + test("several epochs, messages arrive in order 1 -> 3 -> 5 -> 4 -> 2") { setWriterPartitions(1) setReaderPartitions(1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala new file mode 100644 index 0000000000000..b25e75b3b37a6 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.continuous.shuffle + +import org.apache.spark.{TaskContext, TaskContextImpl} +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection} +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.types.{DataType, IntegerType} + +class ContinuousShuffleReadSuite extends StreamTest { + + private def unsafeRow(value: Int) = { + UnsafeProjection.create(Array(IntegerType : DataType))( + new GenericInternalRow(Array(value: Any))) + } + + private def send(endpoint: RpcEndpointRef, messages: UnsafeRowReceiverMessage*) = { + messages.foreach(endpoint.askSync[Unit](_)) + } + + // In this unit test, we emulate that we're in the task thread where + // ContinuousShuffleReadRDD.compute() will be evaluated. This requires a task context + // thread local to be set. + var ctx: TaskContextImpl = _ + + override def beforeEach(): Unit = { + super.beforeEach() + ctx = TaskContext.empty() + TaskContext.setTaskContext(ctx) + } + + override def afterEach(): Unit = { + ctx.markTaskCompleted(None) + TaskContext.unset() + ctx = null + super.afterEach() + } + + test("receiver stopped with row last") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverEpochMarker(), + ReceiverRow(unsafeRow(111)) + ) + + ctx.markTaskCompleted(None) + val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader + eventually(timeout(streamingTimeout)) { + assert(receiver.asInstanceOf[UnsafeRowReceiver].stopped.get()) + } + } + + test("receiver stopped with marker last") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + endpoint.askSync[Unit](ReceiverRow(unsafeRow(111))) + endpoint.askSync[Unit](ReceiverEpochMarker()) + + ctx.markTaskCompleted(None) + val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader + eventually(timeout(streamingTimeout)) { + assert(receiver.asInstanceOf[UnsafeRowReceiver].stopped.get()) + } + } + + test("one epoch") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverRow(unsafeRow(111)), + ReceiverRow(unsafeRow(222)), + ReceiverRow(unsafeRow(333)), + ReceiverEpochMarker() + ) + + val iter = rdd.compute(rdd.partitions(0), ctx) + assert(iter.toSeq.map(_.getInt(0)) == Seq(111, 222, 333)) + } + + test("multiple epochs") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverRow(unsafeRow(111)), + ReceiverEpochMarker(), + ReceiverRow(unsafeRow(222)), + ReceiverRow(unsafeRow(333)), + ReceiverEpochMarker() + ) + + val firstEpoch = rdd.compute(rdd.partitions(0), ctx) + assert(firstEpoch.toSeq.map(_.getInt(0)) == Seq(111)) + + val secondEpoch = rdd.compute(rdd.partitions(0), ctx) + assert(secondEpoch.toSeq.map(_.getInt(0)) == Seq(222, 333)) + } + + test("empty epochs") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint + send( + endpoint, + ReceiverEpochMarker(), + ReceiverEpochMarker(), + ReceiverRow(unsafeRow(111)), + ReceiverEpochMarker(), + ReceiverEpochMarker(), + ReceiverEpochMarker() + ) + + assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) + assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) + + val thirdEpoch = rdd.compute(rdd.partitions(0), ctx) + assert(thirdEpoch.toSeq.map(_.getInt(0)) == Seq(111)) + + assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) + assert(rdd.compute(rdd.partitions(0), ctx).isEmpty) + } + + test("multiple partitions") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 5) + // Send all data before processing to ensure there's no crossover. + for (p <- rdd.partitions) { + val part = p.asInstanceOf[ContinuousShuffleReadPartition] + // Send index for identification. + send( + part.endpoint, + ReceiverRow(unsafeRow(part.index)), + ReceiverEpochMarker() + ) + } + + for (p <- rdd.partitions) { + val part = p.asInstanceOf[ContinuousShuffleReadPartition] + val iter = rdd.compute(part, ctx) + assert(iter.next().getInt(0) == part.index) + assert(!iter.hasNext) + } + } + + test("blocks waiting for new rows") { + val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1) + + val readRowThread = new Thread { + override def run(): Unit = { + // set the non-inheritable thread local + TaskContext.setTaskContext(ctx) + val epoch = rdd.compute(rdd.partitions(0), ctx) + epoch.next().getInt(0) + } + } + + try { + readRowThread.start() + eventually(timeout(streamingTimeout)) { + assert(readRowThread.getState == Thread.State.WAITING) + } + } finally { + readRowThread.interrupt() + readRowThread.join() + } + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 084f8200102ba..d9fd3ebd3c65d 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -35,7 +35,7 @@ import org.apache.hadoop.hive.ql.exec.Utilities import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.security.{Credentials, UserGroupInformation} -import org.apache.log4j.{Level, Logger} +import org.apache.log4j.Level import org.apache.thrift.transport.TSocket import org.apache.spark.SparkConf @@ -300,10 +300,6 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { private val console = new SessionState.LogHelper(LOG) - if (sessionState.getIsSilent) { - Logger.getRootLogger.setLevel(Level.WARN) - } - private val isRemoteMode = { SparkSQLCLIDriver.isRemoteMode(sessionState) } @@ -315,6 +311,9 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging { // because the Hive unit tests do not go through the main() code path. if (!isRemoteMode) { SparkSQLEnv.init() + if (sessionState.getIsSilent) { + SparkSQLEnv.sparkContext.setLogLevel(Level.WARN.toString) + } } else { // Hive 1.2 + not supported in CLI throw new RuntimeException("Remote operations not supported") diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala index f517bffccdf31..0950b30126773 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala @@ -47,10 +47,10 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" {listener.getOnlineSessionNum} session(s) are online, running {listener.getTotalRunning} SQL statement(s) ++ - generateSessionStatsTable() ++ - generateSQLStatsTable() + generateSessionStatsTable(request) ++ + generateSQLStatsTable(request) } - UIUtils.headerSparkPage("JDBC/ODBC Server", content, parent, Some(5000)) + UIUtils.headerSparkPage(request, "JDBC/ODBC Server", content, parent, Some(5000)) } /** Generate basic stats of the thrift server program */ @@ -67,7 +67,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" } /** Generate stats of batch statements of the thrift server program */ - private def generateSQLStatsTable(): Seq[Node] = { + private def generateSQLStatsTable(request: HttpServletRequest): Seq[Node] = { val numStatement = listener.getExecutionList.size val table = if (numStatement > 0) { val headerRow = Seq("User", "JobID", "GroupID", "Start Time", "Finish Time", "Duration", @@ -76,7 +76,8 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" def generateDataRow(info: ExecutionInfo): Seq[Node] = { val jobLink = info.jobId.map { id: String => - + [{id}] } @@ -138,7 +139,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" } /** Generate stats of batch sessions of the thrift server program */ - private def generateSessionStatsTable(): Seq[Node] = { + private def generateSessionStatsTable(request: HttpServletRequest): Seq[Node] = { val sessionList = listener.getSessionList val numBatches = sessionList.size val table = if (numBatches > 0) { @@ -146,8 +147,8 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" val headerRow = Seq("User", "IP", "Session ID", "Start Time", "Finish Time", "Duration", "Total Execute") def generateDataRow(session: SessionInfo): Seq[Node] = { - val sessionLink = "%s/%s/session?id=%s" - .format(UIUtils.prependBaseUri(parent.basePath), parent.prefix, session.sessionId) + val sessionLink = "%s/%s/session?id=%s".format( + UIUtils.prependBaseUri(request, parent.basePath), parent.prefix, session.sessionId) {session.userName} {session.ip} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala index 5cd2fdf6437c2..c884aa0ecbdf8 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala @@ -56,9 +56,9 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) Session created at {formatDate(sessionStat.startTimestamp)}, Total run {sessionStat.totalExecution} SQL ++ - generateSQLStatsTable(sessionStat.sessionId) + generateSQLStatsTable(request, sessionStat.sessionId) } - UIUtils.headerSparkPage("JDBC/ODBC Session", content, parent, Some(5000)) + UIUtils.headerSparkPage(request, "JDBC/ODBC Session", content, parent, Some(5000)) } /** Generate basic stats of the thrift server program */ @@ -75,7 +75,7 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) } /** Generate stats of batch statements of the thrift server program */ - private def generateSQLStatsTable(sessionID: String): Seq[Node] = { + private def generateSQLStatsTable(request: HttpServletRequest, sessionID: String): Seq[Node] = { val executionList = listener.getExecutionList .filter(_.sessionId == sessionID) val numStatement = executionList.size @@ -86,7 +86,8 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) def generateDataRow(info: ExecutionInfo): Seq[Node] = { val jobLink = info.jobId.map { id: String => - + [{id}] } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index bb134bbe68bd9..cd321d41f43e8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -62,7 +62,7 @@ private[spark] object HiveUtils extends Logging { val HIVE_METASTORE_VERSION = buildConf("spark.sql.hive.metastore.version") .doc("Version of the Hive metastore. Available options are " + - s"0.12.0 through 2.3.2.") + s"0.12.0 through 2.3.3.") .stringConf .createWithDefault(builtinHiveVersion) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index c2690ec32b9e7..2f34f69b5cf48 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -98,7 +98,7 @@ private[hive] object IsolatedClientLoader extends Logging { case "2.0" | "2.0.0" | "2.0.1" => hive.v2_0 case "2.1" | "2.1.0" | "2.1.1" => hive.v2_1 case "2.2" | "2.2.0" => hive.v2_2 - case "2.3" | "2.3.0" | "2.3.1" | "2.3.2" => hive.v2_3 + case "2.3" | "2.3.0" | "2.3.1" | "2.3.2" | "2.3.3" => hive.v2_3 } private def downloadVersion( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala index 681ee9200f02b..25e9886fa6576 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -75,7 +75,7 @@ package object client { exclusions = Seq("org.apache.curator:*", "org.pentaho:pentaho-aggdesigner-algorithm")) - case object v2_3 extends HiveVersion("2.3.2", + case object v2_3 extends HiveVersion("2.3.3", exclusions = Seq("org.apache.curator:*", "org.pentaho:pentaho-aggdesigner-algorithm")) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala index 6748dd4ec48e3..ca9da6139649a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala @@ -47,6 +47,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { } private def generateJobRow( + request: HttpServletRequest, outputOpData: OutputOperationUIData, outputOpDescription: Seq[Node], formattedOutputOpDuration: String, @@ -54,7 +55,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { isFirstRow: Boolean, jobIdWithData: SparkJobIdWithUIData): Seq[Node] = { if (jobIdWithData.jobData.isDefined) { - generateNormalJobRow(outputOpData, outputOpDescription, formattedOutputOpDuration, + generateNormalJobRow(request, outputOpData, outputOpDescription, formattedOutputOpDuration, numSparkJobRowsInOutputOp, isFirstRow, jobIdWithData.jobData.get) } else { generateDroppedJobRow(outputOpData, outputOpDescription, formattedOutputOpDuration, @@ -89,6 +90,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { * one cell, we use "rowspan" for the first row of an output op. */ private def generateNormalJobRow( + request: HttpServletRequest, outputOpData: OutputOperationUIData, outputOpDescription: Seq[Node], formattedOutputOpDuration: String, @@ -106,7 +108,8 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { dropWhile(_.failureReason == None).take(1). // get the first info that contains failure flatMap(info => info.failureReason).headOption.getOrElse("") val formattedDuration = duration.map(d => SparkUIUtils.formatDuration(d)).getOrElse("-") - val detailUrl = s"${SparkUIUtils.prependBaseUri(parent.basePath)}/jobs/job?id=${sparkJob.jobId}" + val detailUrl = s"${SparkUIUtils.prependBaseUri( + request, parent.basePath)}/jobs/job?id=${sparkJob.jobId}" // In the first row, output op id and its information needs to be shown. In other rows, these // cells will be taken up due to "rowspan". @@ -196,6 +199,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { } private def generateOutputOpIdRow( + request: HttpServletRequest, outputOpData: OutputOperationUIData, sparkJobs: Seq[SparkJobIdWithUIData]): Seq[Node] = { val formattedOutputOpDuration = @@ -212,6 +216,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { } else { val firstRow = generateJobRow( + request, outputOpData, description, formattedOutputOpDuration, @@ -221,6 +226,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { val tailRows = sparkJobs.tail.map { sparkJob => generateJobRow( + request, outputOpData, description, formattedOutputOpDuration, @@ -278,7 +284,9 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { /** * Generate the job table for the batch. */ - private def generateJobTable(batchUIData: BatchUIData): Seq[Node] = { + private def generateJobTable( + request: HttpServletRequest, + batchUIData: BatchUIData): Seq[Node] = { val outputOpIdToSparkJobIds = batchUIData.outputOpIdSparkJobIdPairs.groupBy(_.outputOpId). map { case (outputOpId, outputOpIdAndSparkJobIds) => // sort SparkJobIds for each OutputOpId @@ -301,7 +309,7 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") { { outputOpWithJobs.map { case (outputOpData, sparkJobs) => - generateOutputOpIdRow(outputOpData, sparkJobs) + generateOutputOpIdRow(request, outputOpData, sparkJobs) } } @@ -364,9 +372,10 @@ private[ui] class BatchPage(parent: StreamingTab) extends WebUIPage("batch") {
    - val content = summary ++ generateJobTable(batchUIData) + val content = summary ++ generateJobTable(request, batchUIData) - SparkUIUtils.headerSparkPage(s"Details of batch at $formattedBatchTime", content, parent) + SparkUIUtils.headerSparkPage( + request, s"Details of batch at $formattedBatchTime", content, parent) } def generateInputMetadataTable(inputMetadatas: Seq[(Int, String)]): Seq[Node] = { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index 3a176f64cdd60..4ce661bc1144e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -148,7 +148,7 @@ private[ui] class StreamingPage(parent: StreamingTab) /** Render the page */ def render(request: HttpServletRequest): Seq[Node] = { - val resources = generateLoadResources() + val resources = generateLoadResources(request) val basicInfo = generateBasicInfo() val content = resources ++ basicInfo ++ @@ -156,17 +156,17 @@ private[ui] class StreamingPage(parent: StreamingTab) generateStatTable() ++ generateBatchListTables() } - SparkUIUtils.headerSparkPage("Streaming Statistics", content, parent, Some(5000)) + SparkUIUtils.headerSparkPage(request, "Streaming Statistics", content, parent, Some(5000)) } /** * Generate html that will load css/js files for StreamingPage */ - private def generateLoadResources(): Seq[Node] = { + private def generateLoadResources(request: HttpServletRequest): Seq[Node] = { // scalastyle:off - - - + + + // scalastyle:on }